From 547378f5828a9dea80e3cb68c76dcab06cca2666 Mon Sep 17 00:00:00 2001
From: hedongdong
Date: Thu, 17 Jul 2025 20:47:36 +0800
Subject: [PATCH 1/7] merge tensor-storage-refactor

---
 .gitmodules                                   |   2 +-
 mindspore-lite/cmake/ccsrc_converter.cmake    |   2 +-
 mindspore-lite/src/extendrt/CMakeLists.txt    |   1 +
 .../cxx_api/llm_engine/llm_engine_plugin.cc   |   3 +-
 .../src/extendrt/cxx_api/model/model_impl.cc  |   8 +-
 .../delegate/ascend_ge/ge_graph_executor.cc   |   5 +-
 .../graph_executor/litert/graph_executor.cc   |   7 +-
 .../tensorrt/tensorrt_graph_executor.cc       |   9 +-
 .../delegate/tensorrt/tensorrt_subgraph.cc    |  32 +--
 .../src/extendrt/lite_device_address.cc       | 230 ++++++++++++++++++
 .../src/extendrt/lite_device_address.h        |  65 +++++
 .../mindir_model/mindir_model_util.cc         |   5 +-
 .../extendrt/session/ascend_native_session.cc |   3 +-
 .../src/extendrt/session/default_session.cc   |   3 +-
 .../src/extendrt/session/single_op_session.cc |  18 +-
 .../src/extendrt/utils/func_graph_utils.cc    |   7 +-
 .../src/extendrt/utils/tensor_utils.cc        |   8 +-
 .../src/extendrt/utils/tensor_utils.h         |   9 +-
 .../test/common/import_from_meta_graphT.cc    |   3 +-
 .../tools/common/custom_ascend_utils.cc       |   5 +-
 mindspore-lite/tools/common/tensor_util.cc    |  19 +-
 mindspore-lite/tools/converter/CMakeLists.txt |   7 +
 .../cxx_api/graph/ascend/ascend_graph_impl.cc |   6 +-
 .../cxx_api/graph/gpu/gpu_graph_impl.cc       |  14 +-
 .../cxx_api_lite/cxx_api/graph/graph_impl.h   |   3 +-
 .../cxx_api/model/acl/acl_model_multi.cc      |   3 +-
 .../adapter/acl/mapper/squeeze_mapper.cc      |   2 +-
 .../tools/converter/export_model.cc           |   7 +-
 .../tools/converter/import/mindir_adjust.cc   |   7 +-
 .../converter/offline_packing_optimizer.cc    |   2 +-
 .../parser/onnx/onnx_constant_parser.cc       |   3 +-
 .../parser/onnx/onnx_model_parser.cc          |   3 +-
 .../converter/parser/onnx/onnx_node_parser.cc |  13 +-
 .../converter/parser/tf/tf_model_parser.cc    |   3 +-
 .../quantizer/cluster_quantization.cc         |   2 +-
 .../converter/quantizer/gptq_quantizer.cc     |   8 +-
 .../converter/quantizer/huffman_encode.cc     |   4 +-
 .../quant_helper/transform_uint8_pass.cc      |   2 +-
 .../converter/quantizer/quantize_util.cc      |   4 +-
 .../converter/quantizer/split_shared_bias.cc  |   7 +-
 .../converter/quantizer/tensor_compressor.cc  |   6 +-
 .../converter/quantizer/tensor_compressor.h   |   6 +-
 .../converter/format_recognition.cc           |   2 +-
 .../converter/preprocess_weight.cc            |   9 +-
 .../tools/lite_exporter/fetch_content.cc      |  52 ++--
 .../mindir_exporter/mindir_serializer.cc      |   4 +-
 .../tools/optimizer/common/format_utils.cc    |   3 +-
 .../tools/optimizer/common/gllo_utils.cc      |   3 +-
 .../tools/optimizer/const_fold/fold_utils.cc  |   3 +-
 .../optimizer/fusion/batchmatmul_fusion.cc    |   3 +-
 .../optimizer/fusion/decoder_layer_fusion.cc  |   2 +-
 .../optimizer/fusion/encoder_layer_fusion.cc  |   4 +-
 .../fusion/kv_cache_mgr_one_branch_fusion.cc  |   2 +-
 .../fusion/multi_head_attention_fusion.cc     |   7 +-
 .../fusion/reduce_same_op_in_horizon.cc       |   5 +-
 .../fusion/tf_bidirection_gru_fusion.cc       |   3 +-
 .../optimizer/graph/grouped_matmul_op_pass.cc |   3 +-
 .../graph/input_and_output_variable_pass.cc   |   3 +-
 .../optimizer/graph/lite_tensor_extractor.cc  |   2 +-
 .../optimizer/graph/miniaturization_pass.cc   |  18 +-
 .../tools/optimizer/graph/node_infershape.cc  |   3 +-
 .../optimizer/graph/output_variable_pass.cc   |   3 +-
 .../tools/optimizer/graph/scalar_op_pass.cc   |   3 +-
 .../parallel/depthwise_conv2d_info.cc         |   3 +-
 64 files changed, 533 insertions(+), 163 deletions(-)
 create mode 100644 mindspore-lite/src/extendrt/lite_device_address.cc
 create mode 100644 mindspore-lite/src/extendrt/lite_device_address.h
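The hunks that follow all apply one refactor: a tensor::Tensor no longer carries a bare TensorData; host storage is wrapped in a DeviceAddress obtained from MakeDeviceAddress, and freshly allocated host tensors come from tensor::empty. A minimal sketch of the two construction patterns this patch converges on (the wrapper function itself is illustrative, not part of the patch):

    // Sketch only: mirrors the call shapes used throughout this patch.
    #include "ir/device_address_maker.h"  // MakeDeviceAddress (new include in the hunks below)
    #include "ir/tensor_api.h"            // tensor::empty (new include in the hunks below)

    tensor::TensorPtr MakeHostTensor(TypeId type_id, const ShapeVector &shape,
                                     const tensor::TensorDataPtr &data) {
      if (data == nullptr) {
        // Fresh host tensor, explicitly placed on the CPU "device".
        return tensor::empty(type_id, shape, device::DeviceType::kCPU);
      }
      // Existing buffer: wrap it in a device address instead of passing TensorData directly.
      return std::make_shared<tensor::Tensor>(type_id, shape, MakeDeviceAddress(type_id, shape, data));
    }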
diff --git a/.gitmodules b/.gitmodules
index d7f1a58b..16b5a0ee 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -2,4 +2,4 @@
 	path = mindspore
 	url = https://gitee.com/mindspore/mindspore.git
 	# shallow = true
-	branch = r2.7.rc1
\ No newline at end of file
+	branch = tensor-storage-refactor
diff --git a/mindspore-lite/cmake/ccsrc_converter.cmake b/mindspore-lite/cmake/ccsrc_converter.cmake
index 51815b7c..22480151 100644
--- a/mindspore-lite/cmake/ccsrc_converter.cmake
+++ b/mindspore-lite/cmake/ccsrc_converter.cmake
@@ -24,6 +24,7 @@ if(MSLITE_ENABLE_CONVERTER)
         ${OPS_DIR}/kernel/common/kernel_factory.cc
         ${OPS_DIR}/kernel/common/format_utils.cc
         ${CCSRC_DIR}/utils/convert_utils.cc
+        ${CCSRC_DIR}/runtime/device/res_manager/utils/convert_tensor_utils.cc
         )

 if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
@@ -42,7 +43,6 @@ if(MSLITE_ENABLE_CONVERTER)
         ${OPS_DIR}/kernel/common/kernel_build_info.cc
         ${OPS_DIR}/kernel/common/oplib/oplib.cc
         ${CCSRC_DIR}/kernel/kernel_info.cc
-        ${CCSRC_DIR}/runtime/device/res_manager/utils/convert_tensor_utils.cc
         ${CCSRC_DIR}/utils/ms_device_shape_transfer.cc
         ${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc
         ${CCSRC_DIR}/runtime/hardware/device_context_manager.cc
diff --git a/mindspore-lite/src/extendrt/CMakeLists.txt b/mindspore-lite/src/extendrt/CMakeLists.txt
index afd29018..cfe55e2f 100644
--- a/mindspore-lite/src/extendrt/CMakeLists.txt
+++ b/mindspore-lite/src/extendrt/CMakeLists.txt
@@ -54,6 +54,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
         ${CMAKE_CURRENT_SOURCE_DIR}/session/factory.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/memory_offload/infer_strategy_builder.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/infer_device_address.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/lite_device_address.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/utils/kernel_build_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/utils/tensor_utils.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/utils/runtime_utils.cc
diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.cc b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.cc
index a18959d6..38f42b93 100644
--- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.cc
+++ b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.cc
@@ -24,6 +24,7 @@
 #include "ge/llm_engine.h"
 #include "external/ge_common/ge_api_error_codes.h"
 #include "ge/llm_error_codes.h"
+#include "ir/device_address_maker.h"

 namespace mindspore {
 struct LLMModelInfo {
@@ -958,7 +959,7 @@ MSTensor LLMEnginePlugin::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr) {
   }
   auto tensor_data = std::make_shared<TensorRefData>(ge_data, elem_num, ge_tensor.GetSize(), me_shape.size(), deleter);
   auto type_id = device::ascend::TransformUtil::ConvertGeDataType(ge_tensor_desc.GetDataType());
-  auto tensor = std::make_shared<tensor::Tensor>(type_id, me_shape, tensor_data);
+  auto tensor = std::make_shared<tensor::Tensor>(type_id, me_shape, MakeDeviceAddress(type_id, me_shape, tensor_data));
   auto tensor_impl = std::make_shared<TensorTensorImpl>(tensor);
   return MSTensor(tensor_impl);
 }
diff --git a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc
index 706faa57..e8ad77ec 100644
--- a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc
+++ b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc
@@ -48,6 +48,7 @@
 #include "include/api/model_group.h"
 #include "src/common/common.h"
+#include "ir/tensor_api.h"

 namespace mindspore {
 namespace {
 const char *const kExecutionPlan = "execution_plan";
@@ -73,7 +74,7 @@ FuncGraphPtr CreateFuncGraphFromDataFlow(const void *model_data, size_t data_siz
   auto type_ptr = TypeIdToType(kNumberTypeUInt8);
   MS_CHECK_TRUE_RET(type_ptr != nullptr, nullptr);
   ShapeVector shape = {static_cast<int64_t>(data_size)};
-  auto param_tensor = std::make_shared<tensor::Tensor>(kNumberTypeUInt8, shape);
+  auto param_tensor = tensor::empty(kNumberTypeUInt8, shape, device::DeviceType::kCPU);
   MS_CHECK_TRUE_RET(param_tensor != nullptr, nullptr);
   if (param_tensor->Size() != data_size) {
     MS_LOG(ERROR) << "The data size of param value is not equal to the data size: " << data_size;
@@ -886,8 +887,9 @@ Status ModelImpl::Predict(const std::vector<tensor::Tensor> &inputs, std::vector
-      if (outputs->at(i).data_c() != graph_outputs[i].data_c()) {
+      if (outputs->at(i).device_address()->GetMutablePtr() !=
+          graph_outputs[i].device_address()->GetMutablePtr()) {
         output_remain = false;
         break;
       }
     }
diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc
index 3ccf1c55..cd08c146 100644
--- a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc
+++ b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc
@@ -26,6 +26,7 @@
 #include "backend/ge_backend/graph_ir/utils.h"
 #include "common/device_type.h"
 #include "include/common/utils/ms_device_shape_transfer.h"
+#include "ir/device_address_maker.h"
 #include "src/common/common.h"
 #include "src/common/file_utils.h"
 #include "cxx_api/acl_utils.h"
@@ -1406,7 +1407,7 @@ bool GeGraphExecutor::GetOneRealInputs(const FuncGraphPtr &anf_graph, std::vecto
       MS_LOG(ERROR) << "Cannot find input " << input_name << " in input_shape " << input_shape_str;
       return false;
     }
-    input = std::make_shared<tensor::Tensor>(input->data_type(), it->second);
+    input = tensor::empty(input->data_type(), it->second, device::DeviceType::kCPU);
   } else if (GeDynamicUtils::IsDynamicInputShapes({input->shape_c()})) {
     MS_LOG(ERROR) << "Input " << i << " is dynamic shape " << input->shape_c()
                   << ", but there is no input shape specified in AscendDeviceInfo or config file";
@@ -1870,7 +1871,7 @@ tensor::TensorPtr GeGraphExecutor::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor
     return nullptr;
   }
   auto tensor_data = std::make_shared<TensorRefData>(ge_data, elem_num, ge_tensor.GetSize(), me_shape.size(), deleter);
-  return std::make_shared<tensor::Tensor>(type_id, me_shape, tensor_data);
+  return std::make_shared<tensor::Tensor>(type_id, me_shape, MakeDeviceAddress(type_id, me_shape, tensor_data));
 }

 std::vector<tensor::Tensor> GeGraphExecutor::GetOutputInfos(uint32_t graph_id) {
diff --git a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc
index d8f4f210..f5a40a56 100644
--- a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc
+++ b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc
@@ -25,6 +25,7 @@
 #include "src/litert/lite_model.h"
 #include "src/litert/cpu_info.h"
 #include "include/errorcode.h"
+#include "ir/device_address_maker.h"
 #include "flatbuffers/flatbuffers.h"
 #include "extendrt/mock/lite_runtime/converters.h"
 #include "extendrt/delegate/factory.h"
@@ -292,7 +293,8 @@ std::vector<tensor::Tensor> LiteRTGraphExecutor::GetInputInfos(uint32_t) {
     std::vector<int64_t> lite_shape;
     std::transform(shape.begin(), shape.end(), std::back_inserter(lite_shape),
                    [](int c) { return static_cast<int64_t>(c); });
-    auto tmp = tensor::Tensor(type_id, lite_shape);
+    auto tmp =
+      tensor::Tensor(type_id, lite_shape, MakeDeviceAddress(type_id, lite_shape, true, device::DeviceType::kCPU));
     tmp.set_name(inputs[i]->tensor_name());
     input_tensors.push_back(tmp);
   }
@@ -304,7 +306,8 @@ std::vector<tensor::Tensor> LiteRTGraphExecutor::GetOutputInfos(uint32_t) {
   std::vector<tensor::Tensor> output_tensors;
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto type_id = static_cast<TypeId>(outputs[i].DataType());
-    auto tmp = tensor::Tensor(type_id, outputs[i].Shape());
+    auto tmp = tensor::Tensor(type_id, outputs[i].Shape(),
+                              MakeDeviceAddress(type_id, outputs[i].Shape(), true, device::DeviceType::kCPU));
     tmp.set_name(outputs[i].Name());
     output_tensors.push_back(tmp);
   }
diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc
index 83c8414a..c1b507b4 100644
--- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc
+++ b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc
@@ -34,6 +34,7 @@
 #include "src/extendrt/utils/func_graph_utils.h"
 #include "src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.h"
 #include "infer/custom.h"
+#include "ir/device_address_maker.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"

@@ -79,7 +80,7 @@ tensor::TensorPtr GetConstNodeValue(AnfNodePtr input_node) {
   }
   if (value->isa<tensor::Tensor>()) {
     auto tensor = value->cast<tensor::TensorPtr>();
-    if (tensor == nullptr || tensor->data().const_data() == nullptr) {
+    if (tensor == nullptr || tensor->unsafe_data() == nullptr) {
       return nullptr;
     }
     return tensor;
@@ -643,7 +644,8 @@ std::vector<tensor::Tensor> TensorRTExecutor::GetInputInfos(uint32_t) {
   for (auto &tensor_info : inputs_) {
     auto type_id = static_cast<TypeId>(tensor_info.DataType());
     auto shape = tensor_info.Shape();
-    tensors.push_back(tensor::Tensor(type_id, shape));
+    tensors.push_back(
+      tensor::Tensor(type_id, shape, MakeDeviceAddress(type_id, shape, true, device::DeviceType::kCPU)));
   }
   return tensors;
 }
@@ -653,7 +655,8 @@ std::vector<tensor::Tensor> TensorRTExecutor::GetOutputInfos(uint32_t) {
   for (auto &tensor_info : outputs_) {
     auto type_id = static_cast<TypeId>(tensor_info.DataType());
     auto shape = tensor_info.Shape();
-    tensors.push_back(tensor::Tensor(type_id, shape));
+    tensors.push_back(
+      tensor::Tensor(type_id, shape, MakeDeviceAddress(type_id, shape, true, device::DeviceType::kCPU)));
   }
   return tensors;
 }
diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc
index 6982b6b7..8917a6d4 100644
--- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc
+++ b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc
@@ -735,19 +735,18 @@ int TensorRTSubGraph::VSLPreExectute(const std::vector<tensor::Tensor> &inputs,
   const int pos_ids_idx = Num2 + is_expert_ids;
   const int current_idx_idx = Num3 + is_expert_ids;
   if (i == input_ids_idx || i == expert_ids_idx || i == pos_ids_idx) {
-    int *in_ptr = static_cast<int *>(inputs[i].data_ptr()->data());
+    int *in_ptr = static_cast<int *>(inputs[i].data_c());
     int batch = inputs[trt_in_tensor_name_.size() - Num1].ElementsNum();
     int seq = inputs[0].ElementsNum() / batch;
     int export_num = (i != expert_ids_idx) ? Num1 : inputs[i].ElementsNum() / (batch * seq);
-    bool incremental_mode =
-      (static_cast<const int *>(inputs[pos_ids_idx].data().const_data())[0] != 0) ? true : false;
+    bool incremental_mode = (static_cast<const int *>(inputs[pos_ids_idx].unsafe_data())[0] != 0) ? true : false;
     size_t h_token = 0;
     for (int k = 0; k < batch; k++) {
       int actual_seq_len =
         (incremental_mode)
           ? Num1
-          : (static_cast<const int *>(inputs[trt_in_tensor_name_.size() - Num1].data().const_data())[k] + Num1);
-      int batch_valid = static_cast<const int *>(inputs[trt_in_tensor_name_.size() - Num1].data().const_data())[k];
+          : (static_cast<const int *>(inputs[trt_in_tensor_name_.size() - Num1].unsafe_data())[k] + Num1);
+      int batch_valid = static_cast<const int *>(inputs[trt_in_tensor_name_.size() - Num1].unsafe_data())[k];
       h_token += (batch_valid == -1) ? 0 : actual_seq_len;
     }
     for (int j = 0; j < export_num; j++) {
@@ -756,10 +755,9 @@ int TensorRTSubGraph::VSLPreExectute(const std::vector<tensor::Tensor> &inputs,
         int actual_seq_len =
           (incremental_mode)
             ? Num1
-            : (static_cast<const int *>(inputs[trt_in_tensor_name_.size() - Num1].data().const_data())[k] + Num1);
+            : (static_cast<const int *>(inputs[trt_in_tensor_name_.size() - Num1].unsafe_data())[k] + Num1);
         for (int l = 0; l < actual_seq_len; l++) {
-          in_ptr[j * h_token + h_token_idx + l] =
-            static_cast<int *>(inputs[i].data_ptr()->data())[j * batch * seq + k * seq + l];
+          in_ptr[j * h_token + h_token_idx + l] = static_cast<int *>(inputs[i].data_c())[j * batch * seq + k * seq + l];
         }
         h_token_idx += actual_seq_len;
       }
@@ -788,12 +786,17 @@ int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs, cons
   if (ret != RET_OK) {
     return ret;
   }
+  auto hasDeviceData = [&](const tensor::Tensor &t) -> bool {
+    auto device_address = t.device_address();
+    return device_address != nullptr && device_address->GetMutablePtr() != nullptr &&
+           device_address->GetDeviceType() != device::DeviceType::kCPU;
+  };
+
   for (size_t i = 0; i < trt_in_tensor_name_.size(); i++) {
     auto trt_tensor_name = trt_in_tensor_name_[i];
     void *device_ptr = nullptr;
-    auto input_device_address = inputs[i].device_address();
-    if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr) {
-      device_ptr = input_device_address->GetMutablePtr();
+    if (hasDeviceData(inputs[i])) {
+      device_ptr = inputs[i].device_address()->GetMutablePtr();
     } else {
       device_ptr = runtime_->GetAllocator()->MallocDeviceMem(trt_tensor_name, inputs_[i].DataSize(),
                                                              ConvertDataType(inputs_[i].DataType()));
@@ -822,7 +825,7 @@ int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs, cons
     void *device_ptr = nullptr;
     if (outputs.size() > i) {
       auto &output = outputs[i];
-      if (output.device_address() && output.device_address()->GetMutablePtr()) {
+      if (hasDeviceData(output)) {
         if (output.Size() < outputs_[i].DataSize()) {
           MS_LOG(ERROR) << "Specified output device data size " << output.Size()
                         << " cannot less than execute output data size " << outputs_[i].DataSize()
@@ -832,7 +835,7 @@ int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs, cons
         device_ptr = output.device_address()->GetMutablePtr();
       }
     }
-    if (!device_ptr) {
+    if (device_ptr == nullptr) {
       device_ptr = runtime_->GetAllocator()->MallocDeviceMem(trt_out_tensor_name, outputs_[i].DataSize(),
                                                              ConvertDataType(outputs_[i].DataType()));
       if (device_ptr == nullptr) {
@@ -864,7 +867,8 @@ int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs, bool syn
     if (has_outputs) {
       auto &tensor = outputs->at(i);
       auto dst_device = tensor.device_address();
-      if (dst_device == nullptr || dst_device->GetMutablePtr() == nullptr) {
+      if (dst_device == nullptr || dst_device->GetMutablePtr() == nullptr ||
+          dst_device->GetDeviceType() == device::DeviceType::kCPU) {
         if (tensor.Size() < outputs_[i].DataSize()) {
           MS_LOG(ERROR) << "Specified output host data size " << tensor.Size()
                         << " cannot less than execute output data size " << outputs_[i].DataSize()
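The tensorrt_subgraph.cc hunk above introduces a hasDeviceData lambda that the session files below repeat: after this refactor a CPU DeviceAddress is just host memory, so a tensor only counts as device-resident when its address is non-null, populated, and not kCPU. A standalone sketch of how that predicate drives buffer selection (ResolveDevicePtr and the allocator callback are illustrative names, not patch API):

    // Sketch: pick a device-usable pointer for a tensor, under the assumptions above.
    void *ResolveDevicePtr(const tensor::Tensor &t, const std::function<void *(size_t)> &alloc_device) {
      auto addr = t.device_address();
      if (addr != nullptr && addr->GetMutablePtr() != nullptr &&
          addr->GetDeviceType() != device::DeviceType::kCPU) {
        return addr->GetMutablePtr();  // genuinely device-resident: use in place
      }
      return alloc_device(t.Size());  // CPU-backed or empty: stage through a device buffer
    }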
diff --git a/mindspore-lite/src/extendrt/lite_device_address.cc b/mindspore-lite/src/extendrt/lite_device_address.cc
new file mode 100644
index 00000000..11db3820
--- /dev/null
+++ b/mindspore-lite/src/extendrt/lite_device_address.cc
@@ -0,0 +1,230 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/extendrt/lite_device_address.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "ir/device_address_maker.h"
+#include "runtime/device/res_manager/utils/convert_tensor_utils.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace runtime {
+namespace test {
+namespace {
+DeviceAddressPtr CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector, const Format &format,
+                                     TypeId type_id, const std::string &device_name, uint32_t device_id,
+                                     uint32_t stream_id, const UserDataPtr &user_data = nullptr) {
+  return std::make_shared<TestDeviceAddress>(ptr, size, "DefaultFormat", type_id, device_name, device_id);
+}
+
+DeviceSyncPtr MakeTestDeviceAddress(TypeId data_type, const ShapeVector &shape, void *data_ptr,
+                                    DeviceAddressDeleter &&deleter) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  auto device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+  auto data_size = SizeOf(shape) * abstract::TypeIdSize(data_type);
+  auto device_address =
+    CreateDeviceAddress(data_ptr, data_size, shape, Format::DEFAULT_FORMAT, data_type, "CPU", device_id, 0);
+  device_address->SetPointerRefCountDeleter(std::move(deleter));
+  return device_address;
+}
+
+const char device_name[] = "CPU";
+REGISTER_DEVICE_ADDRESS_MAKER(device::DeviceType::kCPU, [](TypeId data_type, const ShapeVector &shape, void *data_ptr,
+                                                           DeviceAddressDeleter &&deleter) {
+  return MakeTestDeviceAddress(data_type, shape, data_ptr, std::move(deleter));
+});
+
+// clang-format off
+#define FOR_EACH_TYPE_BASE(M)                       \
+  M(kNumberTypeBool, bool)                          \
+  M(kNumberTypeUInt8, uint8_t)                      \
+  M(kNumberTypeInt4, int8_t)                        \
+  M(kNumberTypeInt8, int8_t)                        \
+  M(kNumberTypeInt16, int16_t)                      \
+  M(kNumberTypeInt32, int32_t)                      \
+  M(kNumberTypeInt64, int64_t)                      \
+  M(kNumberTypeUInt16, uint16_t)                    \
+  M(kNumberTypeUInt32, uint32_t)                    \
+  M(kNumberTypeUInt64, uint64_t)                    \
+  M(kNumberTypeFloat16, float16)                    \
+  M(kNumberTypeFloat32, float)                      \
+  M(kNumberTypeFloat64, double)                     \
+  M(kNumberTypeFloat8E4M3FN, float8_e4m3fn)         \
+  M(kNumberTypeFloat8E5M2, float8_e5m2)             \
+  M(kNumberTypeHiFloat8, hifloat8)                  \
+  M(kNumberTypeComplex64, ComplexStorage<float>)    \
+  M(kNumberTypeComplex128, ComplexStorage<double>)
+
+#ifndef KERNEL_EXECUTOR_ANDROID
+#define FOR_EACH_TYPE_EXTRA(M) M(kNumberTypeBFloat16, bfloat16)
+#else
+#define FOR_EACH_TYPE_EXTRA(M)
+#endif
+
+#define FOR_EACH_TYPE(M)  \
+  FOR_EACH_TYPE_BASE(M)   \
+  FOR_EACH_TYPE_EXTRA(M)
+
+#define REGISTER_SIZE(address_type_id, address_type) { address_type_id, sizeof(address_type) },
+
+static const std::unordered_map<TypeId, size_t> kTypeSizeMap = {
+  FOR_EACH_TYPE(REGISTER_SIZE)
+};
+
+size_t GetTypeSize(TypeId tid) {
+  return kTypeSizeMap.at(tid);
+}
+
+template <typename T>
+using DstCopyFunc = void (*)(T *src_ptr, void *dst_ptr, size_t size);
+
+template <typename T>
+static const std::unordered_map<TypeId, DstCopyFunc<T>> g_dst_copy_map = {
+#define REGISTER_DST(dst_type_id, dst_type)                     \
+  {dst_type_id, +[](T *src_ptr, void *dst_ptr, size_t size) {   \
+     auto buf = static_cast<dst_type *>(dst_ptr);               \
+     return tensor::TransDataType(src_ptr, buf, size);          \
+   }},
+  FOR_EACH_TYPE(REGISTER_DST)
+#undef REGISTER_DST
+};
+
+template <typename T>
+void CopyData(T *src_ptr, size_t size, void *dst_ptr, TypeId dst_type_id) {
+  auto &m = g_dst_copy_map<T>;
+  auto it = m.find(dst_type_id);
+  if (it == m.end()) {
+    MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported dst data type: " << dst_type_id << ".";
+  }
+  it->second(src_ptr, dst_ptr, size);
+}
+
+using SrcCopyFunc = std::function<void(void *src_ptr, void *dst_ptr, size_t size, TypeId dst_type_id)>;
+
+static const std::unordered_map<TypeId, SrcCopyFunc> g_src_copy_map = {
+#define REGISTER_SRC(src_type_id, src_type)                                           \
+  {src_type_id, +[](void *src_ptr, void *dst_ptr, size_t size, TypeId dst_type_id) {  \
+     auto buf = static_cast<src_type *>(src_ptr);                                     \
+     return CopyData(buf, size, dst_ptr, dst_type_id);                                \
+   }},
+  FOR_EACH_TYPE(REGISTER_SRC)
+#undef REGISTER_SRC
+};
+
+#undef FOR_EACH_TYPE
+#undef FOR_EACH_TYPE_BASE
+#undef FOR_EACH_TYPE_EXTRA
+#undef REGISTER_SIZE
+// clang-format on
+
+void CopyData(const DeviceAddress *src_device_address, const DeviceAddress *dst_device_address) {
+  MS_EXCEPTION_IF_NULL(src_device_address);
+  MS_EXCEPTION_IF_NULL(dst_device_address);
+
+  TypeId src_type_id = src_device_address->type_id();
+  TypeId dst_type_id = dst_device_address->type_id();
+  auto src_size = src_device_address->GetSize() / GetTypeSize(src_type_id);
+  auto dst_size = dst_device_address->GetSize() / GetTypeSize(dst_type_id);
+  if (src_size != dst_size) {
+    MS_LOG(EXCEPTION) << "Not same shape in device address:" << src_device_address->ToString()
+                      << " and:" << dst_device_address->ToString();
+  }
+
+  void *src_ptr = src_device_address->GetMutablePtr();
+  void *dst_ptr = dst_device_address->GetMutablePtr();
+  MS_EXCEPTION_IF_NULL(src_ptr);
+  MS_EXCEPTION_IF_NULL(dst_ptr);
+
+  auto it = g_src_copy_map.find(src_type_id);
+  if (it == g_src_copy_map.end()) {
+    MS_LOG(EXCEPTION) << "Unsupported conversion from " << src_type_id << " to " << dst_type_id;
+  }
+  it->second(src_ptr, dst_ptr, src_size, dst_type_id);
+}
+}  // namespace
+
+bool LiteAsyncCopy(const DeviceSyncPtr &dst_device_sync, const DeviceSyncPtr &src_device_sync, size_t stream_id,
+                   bool) {
+  const auto &dst_device_address = dynamic_cast<DeviceAddress *>(dst_device_sync.get());
+  const auto &src_device_address = dynamic_cast<DeviceAddress *>(src_device_sync.get());
+  MS_EXCEPTION_IF_NULL(dst_device_address);
+  MS_EXCEPTION_IF_NULL(src_device_address);
+  if (dst_device_address->GetSize() == 0 || src_device_address->GetSize() == 0) {
+    MS_LOG(INFO) << "No need sync for dst device address: " << dst_device_address->ToString()
+                 << " and src device address: " << src_device_address->ToString();
+    return true;
+  }
+
+  if (dst_device_address->format() != src_device_address->format()) {
+    MS_LOG(ERROR) << "Format is different, src(format:" << src_device_address->format()
+                  << "), dst(format:" << dst_device_address->format()
+                  << ") for device address:" << dst_device_address;
+    return false;
+  }
+  auto dst_ptr = dst_device_address->GetMutablePtr();
+  auto src_ptr = src_device_address->GetMutablePtr();
+  MS_EXCEPTION_IF_NULL(src_ptr);
+  MS_EXCEPTION_IF_NULL(dst_ptr);
+  if (dst_ptr == src_ptr) {
+    MS_LOG(DEBUG) << "host_ptr is equal to device ptr, request ignored.";
+    return true;
+  }
+  auto dst_type_id = dst_device_address->type_id();
+  auto src_type_id = src_device_address->type_id();
+
+  if (src_type_id == dst_type_id) {
+    if (src_device_address->GetSize() > dst_device_address->GetSize()) {
+      MS_LOG(WARNING) << "Please check whether need sync data, src size: " << src_device_address->GetSize()
+                      << ", dst size: " << dst_device_address->GetSize();
+      return true;
+    }
+    auto ret_code = memcpy_s(dst_ptr, src_device_address->GetSize(), src_ptr, src_device_address->GetSize());
+    // Return ERANGE when the copy size is larger than SECUREC_MEM_MAX_LEN.
+    if (ret_code == ERANGE) {
+      device::ConvertSameType(dst_device_address->GetMutablePtr(), src_device_address->GetMutablePtr(),
+                              dst_device_address->GetSize(), src_type_id);
+      return true;
+    } else if (ret_code != EOK) {
+      MS_LOG(ERROR) << "Failed to copy tensor from device address:" << src_device_address->ToString()
+                    << " to :" << dst_device_address->ToString();
+      return false;
+    } else {
+      return true;
+    }
+  }
+
+  MS_LOG(INFO) << "Types not match. src type: " << TypeIdLabel(src_type_id)
+               << ", dst type: " << TypeIdLabel(dst_type_id) << " device_address:" << dst_device_address << " !";
+  CopyData(src_device_address, dst_device_address);
+  return true;
+}
+
+bool LiteSyncCopy(const DeviceSyncPtr &dst_device_sync, const DeviceSyncPtr &src_device_sync, size_t stream_id) {
+  return LiteAsyncCopy(dst_device_sync, src_device_sync, stream_id, false);
+}
+
+MS_REGISTER_HAL_COPY_FUNC(DeviceType::kCPU,
+                          ([](const DeviceSyncPtr &dst_device_sync, const DeviceSyncPtr &src_device_sync,
+                              size_t stream_id) { return LiteSyncCopy(dst_device_sync, src_device_sync, stream_id); }),
+                          ([](const DeviceSyncPtr &dst_device_sync, const DeviceSyncPtr &src_device_sync,
+                              size_t stream_id,
+                              bool) { return LiteSyncCopy(dst_device_sync, src_device_sync, stream_id); }),
+                          ([](void *dst, const void *src, uint64_t size, size_t stream_id) { return true; }));
+
+}  // namespace test
+}  // namespace runtime
+}  // namespace mindspore
diff --git a/mindspore-lite/src/extendrt/lite_device_address.h b/mindspore-lite/src/extendrt/lite_device_address.h
new file mode 100644
index 00000000..4fffd2e4
--- /dev/null
+++ b/mindspore-lite/src/extendrt/lite_device_address.h
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_EXTENDRT_LITE_DEVICE_ADDRESS_H_
+#define MINDSPORE_LITE_SRC_EXTENDRT_LITE_DEVICE_ADDRESS_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "common/device_address.h"
+
+namespace mindspore {
+namespace runtime {
+namespace test {
+using device::DeviceAddress;
+using device::DeviceAddressPtr;
+using device::DeviceType;
+
+class TestDeviceAddress : public DeviceAddress {
+ public:
+  TestDeviceAddress() : DeviceAddress() {}
+  TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
+  TestDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const std::string &device_name,
+                    uint32_t device_id)
+      : DeviceAddress(ptr, size, format, type_id, device_name, device_id) {}
+  ~TestDeviceAddress() {}
+  virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr,
+                                bool sync_on_demand) const {
+    return true;
+  }
+  virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
+                                const std::string &format) const {
+    return true;
+  }
+  virtual void ClearDeviceMemory() {}
+  DeviceType GetDeviceType() const override { return DeviceType::kCPU; }
+
+  void set_data(tensor::TensorDataPtr &&data) override { data_ = std::move(data); }
+
+  const tensor::TensorDataPtr &data() const override { return data_; }
+
+  bool has_data() const override { return data_ != nullptr; }
+
+ private:
+  // the data for numpy object.
+  tensor::TensorDataPtr data_;
+};
+}  // namespace test
+}  // namespace runtime
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_EXTENDRT_LITE_DEVICE_ADDRESS_H_
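Taken together, the two new files wire CPU-backed storage into the runtime's dispatch tables: REGISTER_DEVICE_ADDRESS_MAKER supplies the factory behind MakeDeviceAddress for kCPU, and MS_REGISTER_HAL_COPY_FUNC supplies the copy hooks that LiteSyncCopy/LiteAsyncCopy implement. A hypothetical walk-through of the registered path (the malloc'd buffer, the deleter, and the MakeDeviceAddress overload taking a deleter are assumptions for illustration, inferred from the maker lambda's signature):

    // Assumed flow: MakeDeviceAddress(..., kCPU) dispatches to MakeTestDeviceAddress above.
    ShapeVector shape = {2, 3};
    void *buf = malloc(2 * 3 * sizeof(float));             // illustrative host buffer
    auto addr = MakeDeviceAddress(kNumberTypeFloat32, shape, buf,
                                  [](void *p) { free(p); });  // buffer released via the deleter
    auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, shape, addr);
    // Copies between two such CPU addresses now route through the registered LiteSyncCopy.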
#include "src/common/log_adapter.h" #include "src/litert/cxx_api/converters.h" +#include "ir/device_address_maker.h" #include "ir/graph_utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" @@ -403,7 +404,7 @@ std::vector AscendNativeSession::LiteTensorToTensor() std::vector shape64; std::transform(shape.begin(), shape.end(), std::back_inserter(shape64), [](int dim) { return static_cast(dim); }); - mindspore::tensor::Tensor tensor(type_id, shape64, ref_tensor_data); + mindspore::tensor::Tensor tensor(type_id, shape64, MakeDeviceAddress(type_id, shape64, ref_tensor_data)); tensors.emplace_back(std::move(tensor)); } return tensors; diff --git a/mindspore-lite/src/extendrt/session/default_session.cc b/mindspore-lite/src/extendrt/session/default_session.cc index faee07e1..72c305f5 100644 --- a/mindspore-lite/src/extendrt/session/default_session.cc +++ b/mindspore-lite/src/extendrt/session/default_session.cc @@ -29,6 +29,7 @@ #include "backend/graph_compiler/graph_partition.h" #include "common/tensor_util.h" #include "litert/cxx_api/tensor/tensor_impl.h" +#include "ir/device_address_maker.h" namespace mindspore { Status DefaultInferSession::Init(const std::shared_ptr &context, const ConfigInfos &config_info) { @@ -338,7 +339,7 @@ std::vector DefaultInferSession::LiteTensorToTensor( std::transform(shape.begin(), shape.end(), std::back_inserter(shape64), [](int dim) { return static_cast(dim); }); - mindspore::tensor::Tensor tensor(type_id, shape64, ref_tensor_data); + mindspore::tensor::Tensor tensor(type_id, shape64, MakeDeviceAddress(type_id, shape64, ref_tensor_data)); auto device_address = abstract_tensor->device_data(); if (device_address != nullptr) { auto lite_device_address = std::make_shared(device_address, abstract_tensor->Size()); diff --git a/mindspore-lite/src/extendrt/session/single_op_session.cc b/mindspore-lite/src/extendrt/session/single_op_session.cc index 4cded112..ca9babc3 100644 --- a/mindspore-lite/src/extendrt/session/single_op_session.cc +++ b/mindspore-lite/src/extendrt/session/single_op_session.cc @@ -32,6 +32,7 @@ #include "src/extendrt/utils/kernel_build_utils.h" #include "src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h" #include "src/common/common.h" +#include "ir/device_address_maker.h" #include "mindspore/ops/infer/custom.h" #include "extendrt/session/factory.h" #include "extendrt/utils/runtime_utils.h" @@ -409,7 +410,7 @@ void SingleOpInferSession::SetBackOutputIfDynamic(std::vector *o }; auto ref_tensor_data = std::make_shared(host_addr->addr, elem_num, host_addr->size, shape.size(), acl_mem_deleter); - (*outputs)[i] = tensor::Tensor(out_type, shape, ref_tensor_data); + (*outputs)[i] = tensor::Tensor(out_type, shape, MakeDeviceAddress(out_type, shape, ref_tensor_data)); MS_LOG(DEBUG) << "resetting kernel tensor shape to 0 for the next prediction"; kernel_args_.outputs[i]->SetShapeVector({0}); } @@ -434,7 +435,8 @@ Status SingleOpInferSession::InitInputOutputData(const std::vectorGetMutablePtr() != nullptr) { + if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr && + input_device_address->GetDeviceType() != device::DeviceType::kCPU) { auto device_ptr = input_device_address->GetMutablePtr(); kernel_args_.inputs[i]->SetData(std::make_shared(device_ptr, input.Size())); kernel_args_.inputs[i]->SetHostData(nullptr); @@ -446,7 +448,11 @@ Status SingleOpInferSession::InitInputOutputData(const std::vectorempty()) { 
std::transform(kernel_args_.outputs.begin(), kernel_args_.outputs.end(), std::back_inserter(*outputs), - [](auto &item) { return tensor::Tensor(item->dtype_id(), item->GetShapeVector()); }); + [](auto &item) { + return tensor::Tensor( + item->dtype_id(), item->GetShapeVector(), + MakeDeviceAddress(item->dtype_id(), item->GetShapeVector(), true, device::DeviceType::kCPU)); + }); } if (outputs->size() != kernel_args_.outputs.size()) { MS_LOG(ERROR) << "Given outputs size " << outputs->size() << " != graph inputs size " @@ -463,7 +469,8 @@ Status SingleOpInferSession::InitInputOutputData(const std::vectorGetMutablePtr() != nullptr) { + if (output_device_address != nullptr && output_device_address->GetMutablePtr() != nullptr && + output_device_address->GetDeviceType() != device::DeviceType::kCPU) { auto device_ptr = output_device_address->GetMutablePtr(); kernel_args_.outputs[i]->SetData(std::make_shared(device_ptr, output.Size())); kernel_args_.outputs[i]->SetHostData(nullptr); @@ -507,7 +514,8 @@ Status SingleOpInferSession::InitVariableWeights(const std::vector(data_type), shape); kernel_tensor->SetData(std::make_shared(input->data_c(), input->Size())); auto input_device_address = input->device_address(); - if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr) { + if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr && + input_device_address->GetDeviceType() != device::DeviceType::kCPU) { auto device_ptr = input_device_address->GetMutablePtr(); kernel_tensor->SetData(std::make_shared(device_ptr, input->Size())); kernel_tensor->SetHostData(nullptr); diff --git a/mindspore-lite/src/extendrt/utils/func_graph_utils.cc b/mindspore-lite/src/extendrt/utils/func_graph_utils.cc index 1bc434c8..f54fe973 100644 --- a/mindspore-lite/src/extendrt/utils/func_graph_utils.cc +++ b/mindspore-lite/src/extendrt/utils/func_graph_utils.cc @@ -35,6 +35,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" +#include "ir/tensor_api.h" namespace mindspore { const PrimitivePtr kPrimMakeTupleV2 = std::make_shared("make_tuple"); ValuePtr FuncGraphUtils::GetNodeValuePtr(AnfNodePtr input_node) { @@ -67,7 +68,7 @@ tensor::TensorPtr FuncGraphUtils::GetConstNodeValue(AnfNodePtr input_node) { } if (value->isa()) { auto tensor = value->cast(); - if (tensor == nullptr || tensor->data().const_data() == nullptr) { + if (tensor == nullptr || tensor->unsafe_data() == nullptr) { return nullptr; } return tensor; @@ -326,7 +327,7 @@ void FuncGraphUtils::GetFuncGraphInputsInfo(const FuncGraphPtr &func_graph, std: auto name = FuncGraphUtils::GetTensorName(tensor); auto data_type = FuncGraphUtils::GetTensorDataType(tensor); auto shape = FuncGraphUtils::GetTensorShape(tensor); - auto ms_tensor = std::make_shared(static_cast(data_type), shape); + auto ms_tensor = tensor::empty(static_cast(data_type), shape, device::DeviceType::kCPU); ms_tensor->set_name(name); inputs->push_back(ms_tensor); inputs_name->push_back(name); @@ -349,7 +350,7 @@ void FuncGraphUtils::GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, std auto name = FuncGraphUtils::GetTensorName(tensor); auto data_type = FuncGraphUtils::GetTensorDataType(tensor); auto shape = FuncGraphUtils::GetTensorShape(tensor); - auto ms_tensor = std::make_shared(static_cast(data_type), shape); + auto ms_tensor = tensor::empty(static_cast(data_type), shape, device::DeviceType::kCPU); ms_tensor->set_name(name); 
outputs->push_back(ms_tensor); output_names->push_back(name); diff --git a/mindspore-lite/src/extendrt/utils/tensor_utils.cc b/mindspore-lite/src/extendrt/utils/tensor_utils.cc index 3d1ce7d1..95b66e85 100644 --- a/mindspore-lite/src/extendrt/utils/tensor_utils.cc +++ b/mindspore-lite/src/extendrt/utils/tensor_utils.cc @@ -24,6 +24,7 @@ #include "common/common_utils.h" #include "mindspore/ccsrc/kernel/framework_utils.h" #include "common/format_utils.h" +#include "ir/device_address_maker.h" namespace mindspore { TensorRefData::TensorRefData(void *data, size_t bytes_size, size_t data_size, size_t ndim, @@ -51,7 +52,7 @@ ssize_t TensorRefData::ndim() const { return static_cast(ndim_); } void *TensorRefData::data() { return data_; } -const void *TensorRefData::const_data() const { return data_; } +void *TensorRefData::const_data() const { return data_; } std::string TensorRefData::ToString(TypeId type, const ShapeVector &shape, bool use_comma) const { std::stringstream stream; @@ -88,7 +89,8 @@ std::vector TensorUtils::MSTensorToTensorPtr(const auto data = ms_tensor.MutableData(); auto data_size = ms_tensor.DataSize(); auto ref_tensor_data = std::make_shared(data, ms_tensor.ElementNum(), data_size, shape.size()); - auto tensor_ptr = std::make_shared(type_id, shape, ref_tensor_data); + auto tensor_ptr = + std::make_shared(type_id, shape, MakeDeviceAddress(type_id, shape, ref_tensor_data)); tensor_ptr->set_name(ms_tensor.Name()); tensor_ptr->set_data_type(type_id); tensor_ptrs.push_back(tensor_ptr); @@ -118,7 +120,7 @@ std::vector TensorUtils::MSTensorToTensor(const std:: auto data = const_cast(ms_tensor.Data().get()); auto data_size = ms_tensor.DataSize(); auto ref_tensor_data = std::make_shared(data, ms_tensor.ElementNum(), data_size, shape.size()); - mindspore::tensor::Tensor tensor(type_id, shape, ref_tensor_data); + mindspore::tensor::Tensor tensor(type_id, shape, MakeDeviceAddress(type_id, shape, ref_tensor_data)); auto device_address = ms_tensor.GetDeviceData(); if (device_address != nullptr) { auto lite_device_address = std::make_shared(device_address, ms_tensor.DataSize()); diff --git a/mindspore-lite/src/extendrt/utils/tensor_utils.h b/mindspore-lite/src/extendrt/utils/tensor_utils.h index 79ef5c2b..f6b20173 100644 --- a/mindspore-lite/src/extendrt/utils/tensor_utils.h +++ b/mindspore-lite/src/extendrt/utils/tensor_utils.h @@ -47,9 +47,7 @@ class TensorRefData : public tensor::TensorData { ssize_t nbytes() const override; ssize_t ndim() const override; void *data() override; - const void *const_data() const override; - bool is_sub_data() const override { return false; } - bool has_sub_data() const override { return false; } + void *const_data() const override; std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override; private: @@ -135,7 +133,7 @@ class TensorTensorImpl : public MutableTensorImpl { void *GetDeviceData() override { MS_EXCEPTION_IF_NULL(tensor_); auto device_address = tensor_->device_address(); - if (device_address == nullptr) { + if (device_address == nullptr || tensor_->device_address()->GetDeviceType() == device::DeviceType::kCPU) { return nullptr; } return device_address->GetMutablePtr(); @@ -143,7 +141,8 @@ class TensorTensorImpl : public MutableTensorImpl { bool IsDevice() const override { MS_EXCEPTION_IF_NULL(tensor_); - return tensor_->device_address() != nullptr; + return tensor_->device_address() != nullptr && + tensor_->device_address()->GetDeviceType() != device::DeviceType::kCPU; } bool IsConst() const override { return false; 
} diff --git a/mindspore-lite/test/common/import_from_meta_graphT.cc b/mindspore-lite/test/common/import_from_meta_graphT.cc index 72387b55..b3d19e6a 100644 --- a/mindspore-lite/test/common/import_from_meta_graphT.cc +++ b/mindspore-lite/test/common/import_from_meta_graphT.cc @@ -24,6 +24,7 @@ #include "include/errorcode.h" #include "src/common/utils.h" #include "tools/common/tensor_util.h" +#include "ir/tensor_api.h" namespace mindspore::lite { AnfNodePtr AnfImporterFromMetaGraphT::GetNode(int tensor_id) { @@ -56,7 +57,7 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { } else { parameter->set_name("const-" + std::to_string(i)); } - tensor::TensorPtr tensor_info = std::make_shared(type_id, shape_vector); + tensor::TensorPtr tensor_info = tensor::empty(type_id, shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; diff --git a/mindspore-lite/tools/common/custom_ascend_utils.cc b/mindspore-lite/tools/common/custom_ascend_utils.cc index f99ff767..d67e1ecf 100644 --- a/mindspore-lite/tools/common/custom_ascend_utils.cc +++ b/mindspore-lite/tools/common/custom_ascend_utils.cc @@ -19,6 +19,7 @@ #include "tools/common/func_graph_utils.h" #include "mindspore/ops/infer/tuple_get_item.h" #include "src/common/common.h" +#include "ir/tensor_api.h" #include "tools/optimizer/common/gllo_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" @@ -90,7 +91,7 @@ ParameterPtr CustomAscendUtils::CreateOmParameter(const FuncGraphPtr &func_graph om_parameter->set_abstract(abstract_tensor); auto param_value = - std::make_shared(kNumberTypeUInt8, ShapeVector({static_cast(om_data.DataSize())})); + tensor::empty(kNumberTypeUInt8, ShapeVector({static_cast(om_data.DataSize())}), device::DeviceType::kCPU); MS_CHECK_TRUE_MSG(param_value != nullptr, nullptr, "param_value is nullptr."); auto tensor_data = param_value->data_c(); MS_CHECK_TRUE_MSG(tensor_data != nullptr, nullptr, "New Tensor failed."); @@ -173,7 +174,7 @@ bool CustomAscendUtils::GetZeroValueRefDatas(const ops::PrimitiveCPtr &primc, auto param_name = GetValue(value_ptr_list[i]); auto data_type = static_cast(GetValue(value_ptr_list[i + 1])); auto param_shape = GetValue(value_ptr_list[i + 2]); - auto tensor = std::make_shared(data_type, param_shape); + auto tensor = tensor::empty(data_type, param_shape, device::DeviceType::kCPU); ref_infos->push_back(std::make_pair(param_name, tensor)); } return true; diff --git a/mindspore-lite/tools/common/tensor_util.cc b/mindspore-lite/tools/common/tensor_util.cc index 319a682d..41e21f23 100644 --- a/mindspore-lite/tools/common/tensor_util.cc +++ b/mindspore-lite/tools/common/tensor_util.cc @@ -20,6 +20,7 @@ #include "tools/common/graph_util.h" #include "abstract/utils.h" #include "nnacl/op_base.h" +#include "ir/tensor_api.h" namespace mindspore::lite { namespace { @@ -76,14 +77,14 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std tensor::TensorPtr tensor_info = nullptr; if (shape.empty() && data_size == mindspore::abstract::TypeIdSize(data_type)) { ShapeVector scalar_shape = {1}; - tensor_info = std::make_shared(data_type, scalar_shape); + tensor_info = tensor::empty(data_type, scalar_shape, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor init failed"; return nullptr; } tensor_info->set_shape({}); } else { - tensor_info = std::make_shared(data_type, shape); + tensor_info = tensor::empty(data_type, shape, device::DeviceType::kCPU); } if 
(tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor init failed"; @@ -97,7 +98,7 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std return nullptr; } MS_CHECK_TRUE_MSG(tensor_info->Size() == data_size, nullptr, "invalid const tensor"); - auto ret = memcpy_s(tensor_info->data_c(), tensor_info->data().nbytes(), data, data_size); + auto ret = memcpy_s(tensor_info->data_c(), tensor_info->DataNBytes(), data, data_size); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s error : " << ret; return nullptr; @@ -149,7 +150,7 @@ int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t return RET_ERROR; } MS_CHECK_TRUE_MSG(tensor_info->Size() == data_size, RET_ERROR, "invalid const tensor"); - auto ret = memcpy_s(tensor_info->data_c(), tensor_info->data().nbytes(), data, data_size); + auto ret = memcpy_s(tensor_info->data_c(), tensor_info->DataNBytes(), data, data_size); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s error : " << ret; return RET_ERROR; @@ -191,10 +192,12 @@ int UpdateTensorTFromTensorInfo(const tensor::TensorPtr &src_tensor, std::unique (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), [](const int64_t &value) { return static_cast(value); }); schema_tensor->dims = dims; - if (src_tensor->data().data() != nullptr) { - schema_tensor->data.resize(src_tensor->data().nbytes()); - if (EOK != memcpy_s(schema_tensor->data.data(), schema_tensor->data.size(), src_tensor->data().data(), - src_tensor->data().nbytes())) { + auto src_device = src_tensor->device_address(); + if (src_device != nullptr && src_device->GetMutablePtr() != nullptr && + src_device->GetDeviceType() != device::DeviceType::kCPU) { + auto data_ptr = src_device->GetMutablePtr(); + schema_tensor->data.resize(src_tensor->DataNBytes()); + if (EOK != memcpy_s(schema_tensor->data.data(), schema_tensor->data.size(), data_ptr, src_tensor->DataNBytes())) { MS_LOG(ERROR) << "memcpy_s failed."; return RET_ERROR; } diff --git a/mindspore-lite/tools/converter/CMakeLists.txt b/mindspore-lite/tools/converter/CMakeLists.txt index bbdf5a36..d243a751 100644 --- a/mindspore-lite/tools/converter/CMakeLists.txt +++ b/mindspore-lite/tools/converter/CMakeLists.txt @@ -188,7 +188,14 @@ endif() set(MODEL_LOADER_FRAMEWORK_SRC ${MODEL_LOADER_FRAMEWORK_SRC} ${SRC_DIR}/extendrt/mindir_loader/model_loader.cc + ${SRC_DIR}/extendrt/lite_device_address.cc ) +if(NOT MSLITE_ENABLE_CONVERTER) + set(MODEL_LOADER_FRAMEWORK_SRC + ${MODEL_LOADER_FRAMEWORK_SRC} + ${CCSRC_DIR}/runtime/device/res_manager/utils/convert_tensor_utils.cc + ) +endif() if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) add_compile_definitions(ENABLE_CLOUD_FUSION_INFERENCE) diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.cc index a531ad4e..aff636f1 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.cc +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.cc @@ -154,11 +154,13 @@ Status AscendGraphImpl::ExecuteModel(const std::vector &request, std:: MS_LOG(ERROR) << "Execute Model Failed"; return kMCFailed; } + + std::vector outputs_cpu; for (const auto &out : outputs) { MS_EXCEPTION_IF_NULL(out); - out->data_sync(); + outputs_cpu.push_back(out->cpu()); } - last_outputs_ = outputs; + last_outputs_ = outputs_cpu; 
reply->clear(); *reply = GetOutputs(); return kSuccess; diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.cc index af7014f8..c563c68d 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.cc +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.cc @@ -194,11 +194,13 @@ Status GPUGraphImpl::ExecuteModel(const std::vector &request, std::vec MS_LOG(ERROR) << "Execute Model Failed"; return kMCFailed; } + + std::vector outputs_cpu; for (const auto &out : outputs) { MS_EXCEPTION_IF_NULL(out); - out->data_sync(); + outputs_cpu.push_back(out->cpu()); } - last_outputs_ = outputs; + last_outputs_ = outputs_cpu; reply->clear(); *reply = GetOutputs(); return kSuccess; @@ -294,11 +296,9 @@ std::vector GPUGraphImpl::GetOutputs() { size_t data_size = tensor->Size(); if (i < last_outputs_.size()) { MS_EXCEPTION_IF_NULL(last_outputs_[i]); - if (last_outputs_[i]->NeedSyncDeviceToHost()) { - last_outputs_[i]->data_sync(false); - } - data = last_outputs_[i]->data_c(); - data_size = last_outputs_[i]->Size(); + auto cpu_tensor = last_outputs_[i]->cpu(); + data = cpu_tensor->data_c(); + data_size = cpu_tensor->Size(); } result[i] = MSTensor(output_names_[i], static_cast(tensor->data_type()), tensor->shape(), data, data_size); diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h index b0f1cb3c..652e00f4 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h @@ -30,6 +30,7 @@ #include "backend/ms_backend/ms_backend.h" #include "backend/backend_manager/backend_jit_config.h" +#include "ir/tensor_api.h" namespace mindspore { class GraphCell::GraphImpl { public: @@ -118,7 +119,7 @@ class GraphCell::GraphImpl { auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter); MS_EXCEPTION_IF_NULL(kernel_build_info); auto data_type = kernel_build_info->GetOutputDeviceType(0); - auto ms_tensor = std::make_shared(data_type, input_shape); + auto ms_tensor = tensor::empty(data_type, input_shape, device::DeviceType::kCPU); inputs->push_back(ms_tensor); inputs_name->push_back(parameter->name()); } diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc index fade3580..78475737 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc @@ -28,6 +28,7 @@ #include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h" #include "cxx_api/model/acl/acl_vm/acl_vm.h" +#include "ir/tensor_api.h" namespace mindspore { API_MODEL_REG(Ascend310, AclModelMulti); @@ -184,7 +185,7 @@ void AclModelMulti::SetInputs() { auto elem = tensor_abs->element(); MS_EXCEPTION_IF_NULL(elem); auto type_id = elem->BuildType()->type_id(); - auto tensor = std::make_shared(type_id, tensor_shape->shape()); + auto tensor = tensor::empty(type_id, tensor_shape->shape(), device::DeviceType::kCPU); std::vector shape = tensor->shape_c(); auto input_tensor = MSTensor::CreateTensor(input_param->name(), 
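The two executor hunks above replace in-place out->data_sync() with a functional out->cpu() that yields a host-resident copy, which GetOutputs() can then read without a further sync. A sketch of the consuming pattern, assuming cpu() returns a tensor::TensorPtr and is cheap for tensors that are already host-backed:

    // Illustrative helper: materialize host-visible copies of model outputs.
    std::vector<tensor::TensorPtr> ToHost(const std::vector<tensor::TensorPtr> &outputs) {
      std::vector<tensor::TensorPtr> host;
      host.reserve(outputs.size());
      for (const auto &out : outputs) {
        MS_EXCEPTION_IF_NULL(out);
        host.push_back(out->cpu());  // host copy; assumed near-no-op for CPU tensors
      }
      return host;
    }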
diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc
index fade3580..78475737 100644
--- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc
+++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc
@@ -28,6 +28,7 @@
 #include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h"
 #include "cxx_api/model/acl/acl_vm/acl_vm.h"
+#include "ir/tensor_api.h"

 namespace mindspore {
 API_MODEL_REG(Ascend310, AclModelMulti);
@@ -184,7 +185,7 @@ void AclModelMulti::SetInputs() {
       auto elem = tensor_abs->element();
       MS_EXCEPTION_IF_NULL(elem);
       auto type_id = elem->BuildType()->type_id();
-      auto tensor = std::make_shared<tensor::Tensor>(type_id, tensor_shape->shape());
+      auto tensor = tensor::empty(type_id, tensor_shape->shape(), device::DeviceType::kCPU);

       std::vector<int64_t> shape = tensor->shape_c();
       auto input_tensor = MSTensor::CreateTensor(input_param->name(),
                                                  static_cast<enum DataType>(tensor->data_type_c()),
diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/squeeze_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/squeeze_mapper.cc
index 3e418ad4..13c5de32 100644
--- a/mindspore-lite/tools/converter/adapter/acl/mapper/squeeze_mapper.cc
+++ b/mindspore-lite/tools/converter/adapter/acl/mapper/squeeze_mapper.cc
@@ -51,7 +51,7 @@ bool SqueezeMapper::GetAxisValue(AnfNodePtr input_node, std::vector<int64_t> *ax
   }
   if (value->isa<tensor::Tensor>()) {
     auto tensor = value->cast<tensor::TensorPtr>();
-    if (tensor == nullptr || tensor->data().const_data() == nullptr) {
+    if (tensor == nullptr || tensor->unsafe_data() == nullptr) {
       return false;
     }
     if (tensor->data_type() == kNumberTypeInt64) {
diff --git a/mindspore-lite/tools/converter/export_model.cc b/mindspore-lite/tools/converter/export_model.cc
index 18b38208..e8163158 100644
--- a/mindspore-lite/tools/converter/export_model.cc
+++ b/mindspore-lite/tools/converter/export_model.cc
@@ -39,6 +39,7 @@
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
+#include "ir/tensor_api.h"

 namespace mindspore {
 namespace lite {
 namespace {
@@ -112,7 +113,7 @@ AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const
   }
   std::shared_ptr<tensor::Tensor> tensor_info;
   if (static_cast<TensorCompressionType>(data_info.compress_type_) == TensorCompressionType::kNoCompression) {
-    tensor_info = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec);
+    tensor_info = tensor::empty(static_cast<TypeId>(data_info.data_type_), shape_vec, device::DeviceType::kCPU);
   } else {
     tensor_info = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec,
                                                    data_info.data_.size(),
@@ -121,11 +122,11 @@ AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const
   MS_CHECK_TRUE_RET(tensor_info != nullptr, nullptr);
   if (!data_info.data_.empty()) {
     auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
-    if (tensor_data == nullptr || tensor_info->data().nbytes() < 0) {
+    if (tensor_data == nullptr || tensor_info->DataNBytes() < 0) {
       MS_LOG(ERROR) << "tensor info data is nullptr or the size is smaller than zero.";
       return nullptr;
     }
-    if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), data_info.data_.size()) != EOK) {
+    if (memcpy_s(tensor_data, tensor_info->DataNBytes(), data_info.data_.data(), data_info.data_.size()) != EOK) {
       MS_LOG(ERROR) << "memcpy_s failed";
       return nullptr;
     }
diff --git a/mindspore-lite/tools/converter/import/mindir_adjust.cc b/mindspore-lite/tools/converter/import/mindir_adjust.cc
index 8c727c8f..29075a35 100644
--- a/mindspore-lite/tools/converter/import/mindir_adjust.cc
+++ b/mindspore-lite/tools/converter/import/mindir_adjust.cc
@@ -31,6 +31,7 @@
 #include "infer/fake_quant_param.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h"
+#include "ir/tensor_api.h"

 namespace mindspore {
 namespace lite {
 namespace {
@@ -204,17 +205,17 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
   MS_CHECK_TRUE_MSG(utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape()) != nullptr, RET_NULL_PTR,
                     "Failed to cast pointer.");
   auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
-  auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector);
+  auto dest_tensor_info = tensor::empty(kNumberTypeInt32, shape_vector, device::DeviceType::kCPU);
   MS_CHECK_TRUE_MSG(dest_tensor_info != nullptr, RET_NULL_PTR, "dest_tensor_info is nullptr.");
   MS_CHECK_TRUE_MSG(dest_tensor_info->data_c() != nullptr, RET_ERROR, "dest_tensor_info->data_c() is nullptr");
-  MS_CHECK_TRUE_MSG(dest_tensor_info->data().nbytes() >= static_cast<int64_t>(sizeof(int32_t)), RET_ERROR,
+  MS_CHECK_TRUE_MSG(dest_tensor_info->DataNBytes() >= static_cast<int64_t>(sizeof(int32_t)), RET_ERROR,
                     "num_bits_tensor->data_c() is not longer enough for int32_t");
   auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c());
   MS_CHECK_TRUE_MSG(dest_data_buf != nullptr, RET_NULL_PTR, "dest_data_buf is nullptr.");
   auto src_tensor_info = value->cast<tensor::TensorPtr>();
   MS_CHECK_TRUE_MSG(src_tensor_info != nullptr, RET_NULL_PTR, "src_tensor_info is nullptr.");
   MS_CHECK_TRUE_MSG(src_tensor_info->data_c() != nullptr, RET_ERROR, "src_tensor_info->data_c() is nullptr");
-  MS_CHECK_TRUE_MSG(src_tensor_info->data().nbytes() >= static_cast<int64_t>(sizeof(int64_t)), RET_ERROR,
+  MS_CHECK_TRUE_MSG(src_tensor_info->DataNBytes() >= static_cast<int64_t>(sizeof(int64_t)), RET_ERROR,
                     "num_bits_tensor->data_c() is not longer enough for int64_t");
   auto *src_data_buf = reinterpret_cast<int64_t *>(src_tensor_info->data_c());
   MS_CHECK_TRUE_MSG(dest_tensor_info->ElementsNum() == src_tensor_info->ElementsNum(), RET_ERROR,
diff --git a/mindspore-lite/tools/converter/offline_packing_optimizer.cc b/mindspore-lite/tools/converter/offline_packing_optimizer.cc
index f7d610cb..c2532388 100644
--- a/mindspore-lite/tools/converter/offline_packing_optimizer.cc
+++ b/mindspore-lite/tools/converter/offline_packing_optimizer.cc
@@ -155,7 +155,7 @@ STATUS CreateLiteTensor(const CNodePtr &cnode, std::vector<Tensor *> *in_tensors
     auto param_node = cnode->input(i)->cast<ParameterPtr>();
     if (param_node->has_default()) {
       auto tensor_info = std::static_pointer_cast<tensor::Tensor>(param_node->default_param());
-      tensor_data = tensor_info->data().data();
+      tensor_data = tensor_info->device_address()->GetMutablePtr();
       auto quantization_params = tensor_info->quant_params();
       if (!quantization_params.empty()) {
         auto quantization_param = quantization_params.front();
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc
index 9c3b0608..3f9283da 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc
@@ -26,6 +26,7 @@
 #include "tools/common/tensor_util.h"
 #include "nnacl/op_base.h"
+#include "ir/tensor_api.h"

 namespace mindspore {
 namespace lite {
 namespace {
@@ -55,7 +56,7 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t
     return RET_ERROR;
   }
   std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
-  tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
+  tensor_info = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU);
   MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "create tensor_info return nullptr");
   std::vector<int> shape;
   std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc
index 3366376a..60e0c071 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -53,6 +53,7 @@
 #include "tools/converter/parser/einsum_adjust.h"

 using mindspore::converter::kFmkTypeOnnx;
+#include "ir/tensor_api.h"
 namespace mindspore {
 namespace lite {
 namespace {
@@ -272,7 +273,7 @@ STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::Tensor
       return RET_ERROR;
     }
   } else {
-    tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
+    tensor_info = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU);
     MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_NULL_PTR, "create tensor_info return nullptr");
     std::vector<int> shape;
     std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc
index 7d914229..cc9c4bab 100644
--- a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc
+++ b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc
@@ -24,6 +24,7 @@
 #include "src/common/file_utils.h"
 #include "utils/ms_utils_secure.h"
+#include "ir/tensor_api.h"

 namespace mindspore {
 namespace lite {
 namespace {
@@ -111,7 +112,7 @@ tensor::TensorPtr OnnxNodeParser::CopyOnnxTensorData(const onnx::TensorProto &on
     return nullptr;
   }
   std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
-  auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
+  auto tensor_info = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU);
   if (tensor_info == nullptr) {
     MS_LOG(ERROR) << "new a tensor::Tensor failed, data type: " << data_type << ", shape: " << shape_vector;
     return nullptr;
@@ -140,8 +141,8 @@ tensor::TensorPtr OnnxNodeParser::CopyOnnxTensorData(const onnx::TensorProto &on
     MS_LOG(ERROR) << "Dst tensor cannot be nullptr";
     return nullptr;
   }
-  auto dst_bytes_size = tensor_info->data().nbytes();
-  if (dst_bytes_size != SizeToLong(data_size)) {
+  auto dst_bytes_size = tensor_info->DataNBytes();
+  if (dst_bytes_size != data_size) {
     MS_LOG(ERROR) << "Calculated data size " << data_size << " != tensor bytes size " << dst_bytes_size;
     return nullptr;
   }
@@ -303,10 +304,10 @@ STATUS OnnxNodeParser::LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_
     return RET_MEMORY_FAILED;
   }
   auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
-  if (common::huge_memcpy(tensor_data, static_cast<size_t>(tensor_info->data().nbytes()),
+  if (common::huge_memcpy(tensor_data, static_cast<size_t>(tensor_info->DataNBytes()),
                           static_cast<const uint8_t *>(onnx_data), data_size) != EOK) {
     MS_LOG(ERROR) << "memcpy_s from onnx tensor data to mindspore tensor data failed, dst size "
-                  << tensor_info->data().nbytes() << ", src size " << data_size;
+                  << tensor_info->DataNBytes() << ", src size " << data_size;
     return RET_ERROR;
   }
   return RET_OK;
@@ -349,7 +350,7 @@ static int CopyOnnxData(void *dst_v, const void *src_v, size_t data_count) {

 int OnnxNodeParser::GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t data_count,
                                    const tensor::TensorPtr &tensor_info) {
-  auto data_size = LongToSize(tensor_info->data().nbytes());
+  auto data_size = LongToSize(tensor_info->DataNBytes());
   auto tensor_data = tensor_info->data_c();
   auto onnx_data = onnx_const_tensor.raw_data().data();
   if (onnx_const_tensor.raw_data().size() != data_size) {
diff --git a/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc
index ce4793d7..fb4187cb 100644
--- a/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc
+++ b/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc
@@ -55,6 +55,7 @@
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_w.h"

 using mindspore::converter::kFmkTypeTf;
+#include "ir/tensor_api.h"
 namespace mindspore {
 namespace lite {
 namespace {
@@ -494,7 +495,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co
   for (int i = 0; i < tensor_shape.dim_size(); i++) {
     shape_vector->push_back(tensor_shape.dim(i).size());
   }
-  auto tensor_info = std::make_shared<tensor::Tensor>(type, *shape_vector);
+  auto tensor_info = tensor::empty(type, *shape_vector, device::DeviceType::kCPU);
   if (tensor_info == nullptr) {
     MS_LOG(ERROR) << "tensor info is nullptr";
     return RET_ERROR;
diff --git a/mindspore-lite/tools/converter/quantizer/cluster_quantization.cc b/mindspore-lite/tools/converter/quantizer/cluster_quantization.cc
index 0947bc00..eb06a7ef 100644
--- a/mindspore-lite/tools/converter/quantizer/cluster_quantization.cc
+++ b/mindspore-lite/tools/converter/quantizer/cluster_quantization.cc
@@ -189,7 +189,7 @@ int ClusterQuantization::KMeansQuantization(const CNodePtr &cnode, const std::ve
       MS_LOG(INFO) << "This op " << parameter->fullname_with_scope() << " is bias";
       continue;
     }
-    auto data = static_cast<float *>(tensor_info->data().data());
+    auto data = static_cast<float *>(tensor_info->device_address()->GetMutablePtr());
     std::vector<float> cluster_centroid;
     std::vector<int8_t> clusters;
     auto ret = KMeans(data, tensor_info->DataSize(), k_, max_epochs_, tol_error_, &clusters, &cluster_centroid);
diff --git a/mindspore-lite/tools/converter/quantizer/gptq_quantizer.cc b/mindspore-lite/tools/converter/quantizer/gptq_quantizer.cc
index bdda9f91..24d5625a 100644
--- a/mindspore-lite/tools/converter/quantizer/gptq_quantizer.cc
+++ b/mindspore-lite/tools/converter/quantizer/gptq_quantizer.cc
@@ -211,13 +211,13 @@ int GptqQuantizer::UpdateWeightNode(const FuncGraphPtr &func_graph,
     MS_CHECK_TRUE_MSG(weight_tensor != nullptr, RET_ERROR, "default_param can not cast to tensor::Tensor.");
     weight_tensor->set_data_type(kNumberTypeInt8);
     size_t new_size = weights.at(weight_tensor_name)->elements_num * sizeof(int8_t);
-    if (new_size != static_cast<size_t>(weight_tensor->data().nbytes())) {
+    if (new_size != static_cast<size_t>(weight_tensor->DataNBytes())) {
       MS_LOG(ERROR) << "Data size of tensor info is error, new_size: " << new_size
-                    << ", weight nbytes: " << static_cast<size_t>(weight_tensor->data().nbytes());
+                    << ", weight nbytes: " << static_cast<size_t>(weight_tensor->DataNBytes());
       return RET_ERROR;
     }
-    if (memcpy_s(weight_tensor->data_c(), weight_tensor->data().nbytes(),
-                 weights.at(weight_tensor_name)->quant_data, new_size) != EOK) {
+    if (memcpy_s(weight_tensor->data_c(), weight_tensor->DataNBytes(), weights.at(weight_tensor_name)->quant_data,
+                 new_size) != EOK) {
       MS_LOG(ERROR) << "memcpy data failed.";
       return RET_ERROR;
     }
diff --git a/mindspore-lite/tools/converter/quantizer/huffman_encode.cc b/mindspore-lite/tools/converter/quantizer/huffman_encode.cc
index 666a0c66..5b52a4cc 100644
--- a/mindspore-lite/tools/converter/quantizer/huffman_encode.cc
+++ b/mindspore-lite/tools/converter/quantizer/huffman_encode.cc
@@ -50,11 +50,11 @@ int HuffmanEncode::DoHuffmanEncode(const tensor::TensorPtr &weight, const Primit
   }
   size_t ch_size = huffman_encoded_str_.length();
   if (ch_size < packed_size) {
-    if (ch_size != static_cast<size_t>(weight->data().nbytes())) {
+    if (ch_size != static_cast<size_t>(weight->DataNBytes())) {
       MS_LOG(ERROR) << "Data size of weight is error.";
       return RET_ERROR;
     }
-    if (memcpy_s(weight->data_c(), weight->data().nbytes(), huffman_encoded_str_.c_str(), ch_size) != EOK) {
+    if (memcpy_s(weight->data_c(), weight->DataNBytes(), huffman_encoded_str_.c_str(), ch_size) != EOK) {
       MS_LOG(ERROR) << "memcpy_s failed.";
       return RET_MEMORY_FAILED;
     }
diff --git
a/mindspore-lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc b/mindspore-lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc index dc7475ee..a3c6da58 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc +++ b/mindspore-lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc @@ -104,7 +104,7 @@ int TransformUint8Pass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame // transform weight data size_t elem_count = tensor_info->DataSize(); - auto ret = Uint8toInt8(static_cast(tensor_info->data().data()), elem_count); + auto ret = Uint8toInt8(static_cast(tensor_info->device_address()->GetMutablePtr()), elem_count); if (ret != RET_OK) { MS_LOG(ERROR) << input_node->fullname_with_scope() << " transform data uint8 to int8 failed."; return ret; diff --git a/mindspore-lite/tools/converter/quantizer/quantize_util.cc b/mindspore-lite/tools/converter/quantizer/quantize_util.cc index f9860299..d61f3394 100644 --- a/mindspore-lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore-lite/tools/converter/quantizer/quantize_util.cc @@ -454,11 +454,11 @@ int UpdateTensorDataAndSize(const AnfNodePtr &node, const tensor::TensorPtr &wei MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR); MS_CHECK_TRUE_RET(new_size > 0, RET_NULL_PTR); weight->set_data_type(new_data_type); - if (new_size != static_cast(weight->data().nbytes())) { + if (new_size != static_cast(weight->DataNBytes())) { MS_LOG(ERROR) << "Data size of tensor info is error."; return RET_ERROR; } - if (memcpy_s(weight->data_c(), weight->data().nbytes(), quant_datas, new_size) != EOK) { + if (memcpy_s(weight->data_c(), weight->DataNBytes(), quant_datas, new_size) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; return RET_ERROR; } diff --git a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc index af42e0c5..7addf1e4 100644 --- a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc +++ b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc @@ -23,6 +23,7 @@ #include "tools/converter/quantizer/quantize_util.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" +#include "ir/tensor_api.h" namespace mindspore::lite::quant { AnfNodePtr SplitSharedBias::CloneParameterNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &func_graph, @@ -50,7 +51,7 @@ AnfNodePtr SplitSharedBias::CloneParameterNode(const CNodePtr &cnode, size_t ind } std::shared_ptr tensor_info; if (static_cast(data_info.compress_type_) == TensorCompressionType::kNoCompression) { - tensor_info = std::make_shared(static_cast(data_info.data_type_), shape_vec); + tensor_info = tensor::empty(static_cast(data_info.data_type_), shape_vec, device::DeviceType::kCPU); } else { tensor_info = std::make_shared(static_cast(data_info.data_type_), shape_vec, data_info.data_.size(), @@ -59,11 +60,11 @@ AnfNodePtr SplitSharedBias::CloneParameterNode(const CNodePtr &cnode, size_t ind MS_CHECK_TRUE_RET(tensor_info != nullptr, nullptr); if (!data_info.data_.empty()) { auto tensor_data = reinterpret_cast(tensor_info->data_c()); - if (tensor_data == nullptr || tensor_info->data().nbytes() < 0) { + if (tensor_data == nullptr || tensor_info->DataNBytes() < 0) { MS_LOG(ERROR) << "tensor info data is nullptr or the size is smaller than zero."; return nullptr; } - if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), 
data_info.data_.size()) != EOK) { + if (memcpy_s(tensor_data, tensor_info->DataNBytes(), data_info.data_.data(), data_info.data_.size()) != EOK) { MS_LOG(ERROR) << "memcpy_s failed"; return nullptr; } diff --git a/mindspore-lite/tools/converter/quantizer/tensor_compressor.cc b/mindspore-lite/tools/converter/quantizer/tensor_compressor.cc index 92df29fa..5b3cae23 100644 --- a/mindspore-lite/tools/converter/quantizer/tensor_compressor.cc +++ b/mindspore-lite/tools/converter/quantizer/tensor_compressor.cc @@ -104,7 +104,7 @@ int TensorCompressor::SetNewCompressionTensor(const ParameterPtr &weight, const // set quant param compression_tensor->set_quant_param(tensor_info->quant_params()); // update tensor data - WriteBufferWithAlignByte(bits, static_cast(compression_tensor->data().data())); + WriteBufferWithAlignByte(bits, static_cast(compression_tensor->device_address()->GetMutablePtr())); weight->set_default_param(compression_tensor); weight->set_abstract(compression_tensor->ToAbstract()); return RET_OK; @@ -116,7 +116,7 @@ int TensorCompressor::DoBitPack(const ParameterPtr &weight, size_t bit_num) { auto elements_num = tensor_info->ElementsNum(); std::shared_ptr compression_tensor = nullptr; if (bit_num > 0 && bit_num < k8Bit) { - auto quant_data = static_cast(tensor_info->data().data()); + auto quant_data = static_cast(tensor_info->device_address()->GetMutablePtr()); std::vector origin_data(quant_data, quant_data + elements_num); std::vector pack_data{}; BitPack::BitPacking(bit_num, origin_data, &pack_data); @@ -130,7 +130,7 @@ int TensorCompressor::DoBitPack(const ParameterPtr &weight, size_t bit_num) { return RET_ERROR; } } else if (bit_num > k8Bit && bit_num < k16Bit) { - auto quant_data = static_cast(tensor_info->data().data()); + auto quant_data = static_cast(tensor_info->device_address()->GetMutablePtr()); std::vector origin_data(quant_data, quant_data + elements_num); std::vector pack_data{}; BitPack::BitPacking(bit_num, origin_data, &pack_data); diff --git a/mindspore-lite/tools/converter/quantizer/tensor_compressor.h b/mindspore-lite/tools/converter/quantizer/tensor_compressor.h index 8b8e427c..6d807e82 100644 --- a/mindspore-lite/tools/converter/quantizer/tensor_compressor.h +++ b/mindspore-lite/tools/converter/quantizer/tensor_compressor.h @@ -50,7 +50,7 @@ class TensorCompressor { return RET_OK; } auto max_size = tensor_info->Size(); - auto quant_data_array = static_cast(tensor_info->data().data()); + auto quant_data_array = static_cast(tensor_info->device_address()->GetMutablePtr()); std::vector quant_data(quant_data_array, quant_data_array + max_size / sizeof(T)); auto elem_cnt = quant_data.size(); @@ -128,7 +128,7 @@ class TensorCompressor { auto tensor_info = weight->default_param()->cast(); CHECK_NULL_RETURN(tensor_info); auto max_size = tensor_info->ElementsNum(); - auto quant_data = static_cast(tensor_info->data().data()); + auto quant_data = static_cast(tensor_info->device_address()->GetMutablePtr()); // write the index: each index has unique_value_bit unsigned for (int i = 0; i < max_size; i++) { auto quant_value = quant_data[i]; @@ -157,7 +157,7 @@ class TensorCompressor { size_t nz_cnt, size_t coor_best_bit, size_t bit_num) { auto tensor_info = weight->default_param()->cast(); CHECK_NULL_RETURN(tensor_info); - auto quant_data = static_cast(tensor_info->data().data()); + auto quant_data = static_cast(tensor_info->device_address()->GetMutablePtr()); int elem_cnt = tensor_info->DataSize(); auto channel_cnt = quant_params.size(); if (channel_cnt == 0) { diff --git 
a/mindspore-lite/tools/graph_kernel/converter/format_recognition.cc b/mindspore-lite/tools/graph_kernel/converter/format_recognition.cc index 7f7ff179..e27b0c9e 100644 --- a/mindspore-lite/tools/graph_kernel/converter/format_recognition.cc +++ b/mindspore-lite/tools/graph_kernel/converter/format_recognition.cc @@ -59,7 +59,7 @@ std::pair GetTransposeFormat(const CNodePtr &cnode) { return GetLiteFormat(cnode); } auto perm_tensor = perm_para->default_param()->cast(); - auto perm = static_cast(perm_tensor->data_ptr()->data()); + auto perm = static_cast(perm_tensor->device_address()->GetMutablePtr()); std::transform(perm, perm + perm_tensor->shape()[0], std::back_inserter(perm_list), IntToLong); } else { auto perm_value = cnode->input(perm_idx)->cast(); diff --git a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc b/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc index 6f12ed94..e745e31d 100644 --- a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc +++ b/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc @@ -21,6 +21,7 @@ #include "utils/anf_utils.h" #include "backend/common/graph_kernel/core/graph_kernel_callback.h" #include "backend/common/graph_kernel/core/graph_kernel_utils.h" +#include "ir/tensor_api.h" namespace mindspore::graphkernel { constexpr size_t kConv2dDataIndex = 1; @@ -115,7 +116,7 @@ AnfNodePtr SubstituteConv2D::InferWeightValue(const AnfNodePtr &node) { if (tensor == nullptr) { return nullptr; } - if (tensor->data().const_data() == nullptr) { + if (tensor->unsafe_data() == nullptr) { return nullptr; } if (tensor->data_type() != kNumberTypeFloat32) { @@ -129,7 +130,7 @@ AnfNodePtr SubstituteConv2D::InferWeightValue(const AnfNodePtr &node) { IndexCalc old_shape_calc({c_out_o, c_out_i, h_len, w_len, c_in_o, c_in_i}); ShapeVector new_shape = {c_out_o, c_in_o, h_len, w_len, c_in_i, c_out_i}; IndexCalc new_shape_calc(new_shape); - auto new_tensor = std::make_shared(tensor->data_type(), new_shape); + auto new_tensor = tensor::empty(tensor->data_type(), new_shape, device::DeviceType::kCPU); auto new_data = new_tensor->data_c(); auto old_data = tensor->data_c(); for (int64_t coo = 0; coo < c_out_o; coo++) { @@ -188,7 +189,7 @@ AnfNodePtr MatmulPackB::InferValue(const AnfNodePtr &node) { if (tensor == nullptr) { return node; } - if (tensor->data().const_data() == nullptr) { + if (tensor->unsafe_data() == nullptr) { return node; } @@ -230,7 +231,7 @@ tensor::TensorPtr MatmulPackB::PackB(const tensor::TensorPtr &tensor, const Shap if (transpose) { std::swap(height, width); } - auto new_tensor = std::make_shared(tensor->data_type(), std::vector{height, width}); + auto new_tensor = tensor::empty(tensor->data_type(), std::vector{height, width}, device::DeviceType::kCPU); auto *new_tensor_iter = static_cast(new_tensor->data_c()); int64_t width_offset = 0; for (auto pack : pack_size) { diff --git a/mindspore-lite/tools/lite_exporter/fetch_content.cc b/mindspore-lite/tools/lite_exporter/fetch_content.cc index 97d6a973..0b904a96 100644 --- a/mindspore-lite/tools/lite_exporter/fetch_content.cc +++ b/mindspore-lite/tools/lite_exporter/fetch_content.cc @@ -62,7 +62,10 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap std::string shape_size_str; *offset = 0; size_t cnt = 0; - for (; *offset < tensor_info->Size(); (*offset)++) { + MS_EXCEPTION_IF_NULL(tensor_info->device_address()); + MS_EXCEPTION_IF_NULL(tensor_info->device_address()->data()); + auto tensor_info_nbytes = 
static_cast(tensor_info->device_address()->data()->nbytes()); + for (; *offset < tensor_info_nbytes; (*offset)++) { if (tensor_data[*offset] == ',') { (*offset)++; break; @@ -76,7 +79,7 @@ STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, Shap constexpr int kBase = 10; size_t shape_size = static_cast(std::strtol(shape_size_str.c_str(), nullptr, kBase)); MS_CHECK_TRUE_RET(shape_size != 0, RET_ERROR); - for (; *offset < tensor_info->Size(); (*offset)++) { + for (; *offset < tensor_info_nbytes; (*offset)++) { if (tensor_data[*offset] == ',') { cnt++; int64_t shape = 0; @@ -159,8 +162,11 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, converter::FmkType fmk_ // process weight tensor if (copy_data) { - data_info->data_.resize(data->Size()); - if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) { + MS_EXCEPTION_IF_NULL(data->device_address()); + MS_EXCEPTION_IF_NULL(data->device_address()->data()); + auto data_nbytes = static_cast(data->device_address()->data()->nbytes()); + data_info->data_.resize(data_nbytes); + if (data_nbytes > 0 && memcpy_s(data_info->data_.data(), data_nbytes, data->data_c(), data_nbytes) != EOK) { MS_LOG(ERROR) << "memcpy_s error."; return RET_ERROR; } @@ -260,11 +266,14 @@ int SetTensorData(const tensor::TensorPtr &tensor_info, DataInfo *data_info, Typ bool copy_data) { MS_CHECK_TRUE_RET(data_info != nullptr, RET_NULL_PTR); MS_CHECK_TRUE_RET(tensor_info != nullptr, RET_NULL_PTR); - if (data_type == kObjectTypeTensorType && tensor_info->Size() >= kTensorListMinSize) { - data_info->data_.resize(tensor_info->Size() - offset); + MS_EXCEPTION_IF_NULL(tensor_info->device_address()); + MS_EXCEPTION_IF_NULL(tensor_info->device_address()->data()); + auto tensor_info_nbytes = static_cast(tensor_info->device_address()->data()->nbytes()); + if (data_type == kObjectTypeTensorType && tensor_info_nbytes >= kTensorListMinSize) { + data_info->data_.resize(tensor_info_nbytes - offset); if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(), static_cast(tensor_info->data_c()) + offset, - tensor_info->Size() - offset)) { + tensor_info_nbytes - offset)) { MS_LOG(ERROR) << "memcpy_s failed."; return RET_ERROR; } @@ -272,10 +281,10 @@ int SetTensorData(const tensor::TensorPtr &tensor_info, DataInfo *data_info, Typ // common node with const data if (data_type != kObjectTypeTensorType) { if (copy_data) { - data_info->data_.resize(tensor_info->Size() - offset); + data_info->data_.resize(tensor_info_nbytes - offset); if (EOK != common::huge_memcpy(data_info->data_.data(), data_info->data_.size(), static_cast(tensor_info->data_c()) + offset, - tensor_info->Size() - offset)) { + tensor_info_nbytes - offset)) { MS_LOG(ERROR) << "memcpy_s failed."; return RET_ERROR; } @@ -309,7 +318,10 @@ int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkTy } std::vector dims(shape_vector.begin(), shape_vector.end()); data_info->shape_ = dims; - if (tensor_info != nullptr && tensor_info->Size() != 0) { + MS_EXCEPTION_IF_NULL(tensor_info->device_address()); + MS_EXCEPTION_IF_NULL(tensor_info->device_address()->data()); + auto tensor_info_nbytes = static_cast(tensor_info->device_address()->data()->nbytes()); + if (tensor_info != nullptr && tensor_info_nbytes != 0) { // tensor_list tensor status = SetTensorData(tensor_info, data_info, data_type, offset, copy_data); if (status != RET_OK) { @@ -444,10 +456,12 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, DataInfo 
*data_info) } auto tensor_value = tensor_info->cast(); MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed"); - if (tensor_value->Size() >= kTensorListMinSize) { - data_info->data_.resize(tensor_value->Size()); - if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) != - EOK) { + MS_EXCEPTION_IF_NULL(tensor_value->device_address()); + MS_EXCEPTION_IF_NULL(tensor_value->device_address()->data()); + auto tensor_value_nbytes = static_cast(tensor_value->device_address()->data()->nbytes()); + if (tensor_value_nbytes >= kTensorListMinSize) { + data_info->data_.resize(tensor_value_nbytes); + if (memcpy_s(data_info->data_.data(), tensor_value_nbytes, tensor_value->data_c(), tensor_value_nbytes) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; return RET_ERROR; } @@ -509,10 +523,12 @@ int FetchDataFromAbstract(const AbstractBasePtr &abstract, DataInfo *data_info) } auto tensor_value = tensor_info->cast(); MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed"); - if (tensor_value->Size() >= kTensorListMinSize) { - data_info->data_.resize(tensor_value->Size()); - if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) != - EOK) { + MS_EXCEPTION_IF_NULL(tensor_value->device_address()); + MS_EXCEPTION_IF_NULL(tensor_value->device_address()->data()); + auto tensor_value_nbytes = static_cast(tensor_value->device_address()->data()->nbytes()); + if (tensor_value_nbytes >= kTensorListMinSize) { + data_info->data_.resize(tensor_value_nbytes); + if (memcpy_s(data_info->data_.data(), tensor_value_nbytes, tensor_value->data_c(), tensor_value_nbytes) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; return RET_ERROR; } diff --git a/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc b/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc index 8a7cc81a..f012260a 100644 --- a/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc +++ b/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc @@ -411,7 +411,7 @@ int MindIRSerializer::SaveMindIRTogether(const std::shared_ptr &p } auto data = para->default_param()->cast(); param_proto.clear_raw_data(); - param_proto.set_raw_data(data->data_c(), static_cast(data->data().nbytes())); + param_proto.set_raw_data(data->data_c(), static_cast(data->DataNBytes())); } return SaveProtoToFile(&model_proto_, save_model_path_, param); @@ -562,7 +562,7 @@ int MindIRSerializer::SplitSave(const std::shared_ptr ¶m) { continue; } auto data = para->default_param()->cast(); - int64_t data_length = static_cast(data->data().nbytes()); + int64_t data_length = static_cast(data->DataNBytes()); int64_t append_size = 0; if (data_length % OFFSET != 0) { append_size = OFFSET - (data_length % OFFSET); diff --git a/mindspore-lite/tools/optimizer/common/format_utils.cc b/mindspore-lite/tools/optimizer/common/format_utils.cc index ba90ce18..c2e93a2c 100644 --- a/mindspore-lite/tools/optimizer/common/format_utils.cc +++ b/mindspore-lite/tools/optimizer/common/format_utils.cc @@ -82,6 +82,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" +#include "ir/tensor_api.h" namespace mindspore { namespace opt { // treat the weight of deformableConv2d as an input instead of a const because of the ops infershape only support nchw. 
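// ---------------------------------------------------------------------------
// Editor's sketch (annotation, not part of the patch): the hunks above and
// below all apply one recurring rewrite from the tensor-storage refactor. A
// CPU host tensor that used to be built with
// std::make_shared<tensor::Tensor>(type, shape) is now obtained from the
// tensor::empty() factory with an explicit device target, and its byte size
// is read from DataNBytes() instead of data().nbytes(). A minimal sketch
// under those assumptions follows; the helper name CopyIntoFreshTensor and
// its error handling are illustrative, not code from this repository.
#include "ir/tensor_api.h"  // declares tensor::empty, per the includes this patch adds
tensor::TensorPtr CopyIntoFreshTensor(TypeId type, const ShapeVector &shape, const void *src, size_t src_bytes) {
  // Allocate host storage through the new device-aware factory.
  auto t = tensor::empty(type, shape, device::DeviceType::kCPU);
  if (t == nullptr || t->data_c() == nullptr) {
    return nullptr;
  }
  // The byte count now comes straight from the tensor; assumed non-negative
  // here, as the surrounding hunks also check.
  if (static_cast<size_t>(t->DataNBytes()) < src_bytes) {
    return nullptr;
  }
  // The writable payload is still reached through data_c(), so memcpy_s keeps working.
  if (memcpy_s(t->data_c(), static_cast<size_t>(t->DataNBytes()), src, src_bytes) != EOK) {
    return nullptr;
  }
  return t;
}
// ---------------------------------------------------------------------------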
@@ -341,7 +342,7 @@ int SetAbstractTensorInfo(const AbstractBasePtr &abstract) { TypeId type = lite::GetAbstractTensorDtype(abstract->cast()); // For kObjectTypeTensorType, the abstract value is TensorList amd does not need to reset. if (type != kObjectTypeTensorType) { - auto tensor_info = std::make_shared(type, shape); + auto tensor_info = tensor::empty(type, shape, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed"; return RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/common/gllo_utils.cc b/mindspore-lite/tools/optimizer/common/gllo_utils.cc index 98759ae1..0860fc05 100644 --- a/mindspore-lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore-lite/tools/optimizer/common/gllo_utils.cc @@ -57,6 +57,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" +#include "ir/tensor_api.h" namespace mindspore { namespace opt { namespace { @@ -881,7 +882,7 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::Te } param_node->set_name(node_name); param_node->debug_info()->set_name(node_name); - auto tensor_info_new = std::make_shared(data_type, shape_vector); + auto tensor_info_new = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU); if (tensor_info_new == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed."; return nullptr; diff --git a/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc b/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc index 3114350c..3b0e0ee6 100644 --- a/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc +++ b/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc @@ -43,6 +43,7 @@ using mindspore::lite::KernelRegistry; using mindspore::lite::Tensor; +#include "ir/tensor_api.h" namespace mindspore { namespace opt { namespace { @@ -57,7 +58,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), [](const int32_t &value) { return static_cast(value); }); - auto tensor_info = std::make_shared(tensor->data_type(), shape_vector); + auto tensor_info = tensor::empty(tensor->data_type(), shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "create tensor info failed."; return nullptr; diff --git a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc index bfad59fe..3986b060 100644 --- a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -34,6 +34,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" +#include "ir/tensor_api.h" namespace mindspore::opt { namespace { @@ -240,7 +241,7 @@ int ResetReshapeParameters(const AnfNodePtr &reshape_node) { shape[0] = rmatmul_input_shape[0] + 1; } - auto tensor_info = std::make_shared(shape_tensor->data_type(), shape); + auto tensor_info = tensor::empty(shape_tensor->data_type(), shape, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "Create tensor info failed"; return RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc b/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc index a95d560a..22b9684d 100644 --- 
a/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc @@ -364,7 +364,7 @@ STATUS DecoderLayerFusion::GetEps(const EquivPtr &equiv, VarPtr node_name, float if (value_node->isa()) { auto tensor = value_node->cast(); MS_EXCEPTION_IF_NULL(tensor); - *eps = *reinterpret_cast(tensor->data().data()); + *eps = *reinterpret_cast(tensor->device_address()->GetMutablePtr()); return RET_OK; } } diff --git a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc index 84f8d794..6990f102 100644 --- a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc @@ -928,7 +928,7 @@ STATUS EncoderLayerFusion::GetEps(const EquivPtr &equiv, VarPtr node_name, float if (value_node->isa()) { auto tensor = value_node->cast(); MS_EXCEPTION_IF_NULL(tensor); - *eps = *reinterpret_cast(tensor->data().data()); + *eps = *reinterpret_cast(tensor->device_address()->GetMutablePtr()); return RET_OK; } } @@ -1045,7 +1045,7 @@ STATUS EncoderLayerFusion::InitAttributes(AnfNodePtr k_past, AnfNodePtr begin_ex auto expert_capacity_value_node = utils::cast(utils::cast(expert_capacity_node)->value()); if (expert_capacity_value_node->isa()) { auto tensor = expert_capacity_value_node->cast(); - auto expert_capacity = *(reinterpret_cast(tensor->data().data())); + auto expert_capacity = *(reinterpret_cast(tensor->device_address()->GetMutablePtr())); float cast_expert_capacity = Float16::ToFloat32(expert_capacity); *capacity_factor = (cast_expert_capacity) * (*expert_num) / seq; } diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc index 4734e065..af5eadf1 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc @@ -73,7 +73,7 @@ const BaseRef KVCacheMgrOneBranchFusion::DefinePattern() const { tensor::TensorPtr KVCacheMgrOneBranchFusion::ConstData(int32_t padding_length) const { std::vector shp = {padding_length}; - tensor::TensorPtr const_data = std::make_shared(kInt32->type_id(), shp); + tensor::TensorPtr const_data = tensor::empty(kInt32->type_id(), shp, device::DeviceType::kCPU); MS_CHECK_TRUE_RET(const_data != nullptr && const_data->data_c() != nullptr, nullptr); auto *val = static_cast(const_data->data_c()); for (int i = 0; i < padding_length; ++i) { diff --git a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc index 71f13cd2..22f348f2 100644 --- a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc @@ -42,6 +42,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" +#include "ir/tensor_api.h" namespace mindspore::opt { namespace { @@ -656,12 +657,12 @@ std::shared_ptr ConcatTensors(const std::vector(base_data_type, new_shape); + auto concat_tensor = tensor::empty(base_data_type, new_shape, device::DeviceType::kCPU); MS_CHECK_TRUE_RET(concat_tensor != nullptr, nullptr); std::size_t offset = 0; for (const auto &tensor : tensors) { void *ptr = 
reinterpret_cast(concat_tensor->data_c()) + offset; - auto transpose_tensor = std::make_shared(base_data_type, tensor->shape()); + auto transpose_tensor = tensor::empty(base_data_type, tensor->shape(), device::DeviceType::kCPU); if (transpose && !transpose_b) { switch (base_data_type) { case kNumberTypeFloat32: { @@ -692,7 +693,7 @@ std::shared_ptr ConcatTensors(const std::vector tshape = {new_shape[1], new_shape[0]}; - auto transposed_tensor = std::make_shared(base_data_type, tshape); + auto transposed_tensor = tensor::empty(base_data_type, tshape, device::DeviceType::kCPU); switch (base_data_type) { case kNumberTypeFloat32: { auto status = TransposeMatrix(concat_tensor, transposed_tensor); diff --git a/mindspore-lite/tools/optimizer/fusion/reduce_same_op_in_horizon.cc b/mindspore-lite/tools/optimizer/fusion/reduce_same_op_in_horizon.cc index 82cc54b9..c38cad83 100644 --- a/mindspore-lite/tools/optimizer/fusion/reduce_same_op_in_horizon.cc +++ b/mindspore-lite/tools/optimizer/fusion/reduce_same_op_in_horizon.cc @@ -39,10 +39,7 @@ bool CheckValueIsEqual(const ValuePtr &left, const ValuePtr &right) { auto left_tensor = left->cast(); auto right_tensor = right->cast(); MS_CHECK_TRUE_RET(left_tensor != nullptr && right_tensor != nullptr, false); - auto left_data = left_tensor->data_ptr(); - auto right_data = right_tensor->data_ptr(); - MS_CHECK_TRUE_RET(left_data != nullptr && right_data != nullptr, false); - return left_tensor->tensor::MetaTensor::operator==(*right_tensor) && left_data->equals(*right_data); + return left_tensor->ValueEqual(*right_tensor); } return *left == *right; } diff --git a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc index 81c66cac..54859d04 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc @@ -42,6 +42,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_w.h" +#include "ir/tensor_api.h" namespace mindspore { namespace opt { namespace { @@ -456,7 +457,7 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun } parameter->set_abstract(abstract); - auto gate_weight_default = std::make_shared(type, shape_vector); + auto gate_weight_default = tensor::empty(type, shape_vector, device::DeviceType::kCPU); if (gate_weight_default == nullptr) { MS_LOG(ERROR) << "gate_weight_default is nullptr"; return nullptr; diff --git a/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc b/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc index a9299897..066cb09d 100644 --- a/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc @@ -34,6 +34,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" +#include "ir/tensor_api.h" namespace mindspore::opt { #if !defined(_WIN32) && !defined(_WIN64) @@ -104,7 +105,7 @@ void GroupedMatmulOpPass::UseEmptyNodeReplaceNone(const FuncGraphPtr &graph, con // create empty tensor auto tensor_type = OpInputDtypeMap.at(cnode_name).at(input_idx); std::vector tensor_shape = {0}; - auto empty_tensor = std::make_shared(tensor_type, tensor_shape); + auto empty_tensor = tensor::empty(tensor_type, tensor_shape, 
device::DeviceType::kCPU); // create node auto empty_node = std::make_shared(empty_tensor); ValueNodePtr empty_value_node = empty_node->cast(); diff --git a/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc b/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc index df781e77..63f0e7f4 100644 --- a/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc @@ -29,6 +29,7 @@ #include "mindspore/ops/op_def/sequence_ops.h" #include "tools/common/func_graph_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" +#include "ir/tensor_api.h" namespace mindspore::opt { @@ -154,7 +155,7 @@ CNodePtr InputAndOutputVariablePass::CreateAssign(const AnfNodePtr &anf_node, co MS_LOG(ERROR) << "type ptr is nullptr"; return nullptr; } - tensor::TensorPtr tensor_data = std::make_shared(type_ptr->type_id(), shape); + tensor::TensorPtr tensor_data = tensor::empty(type_ptr->type_id(), shape, device::DeviceType::kCPU); float *val = static_cast(tensor_data->data_c()); for (size_t i = 0; i < tensor_data->DataSize(); ++i) { *(val + i) = 0; diff --git a/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc b/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc index 7d2539af..b86ae3b1 100644 --- a/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc +++ b/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc @@ -216,7 +216,7 @@ int LiteTensorExtractor::GetCNodeConstInputToAbstract(const CNodePtr &cnode, con } auto input_tensor = shape_value->cast(); MS_CHECK_FALSE(input_tensor == nullptr, RET_ERROR); - if (input_tensor->data().const_data() != nullptr) { + if (input_tensor->unsafe_data() != nullptr) { MS_LOG(DEBUG) << "abstract already have const data."; continue; } diff --git a/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc b/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc index 892be53b..155cca61 100644 --- a/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc @@ -38,23 +38,22 @@ static inline tensor::TensorPtr GetTensorFromNode(const AnfNodePtr &node) { return nullptr; } auto tensor = value->cast(); - if (tensor == nullptr || tensor->data_ptr() == nullptr || tensor->data_c() == nullptr) { + if (tensor == nullptr || tensor->device_address() == nullptr || tensor->data_c() == nullptr) { return nullptr; } return tensor; } bool MiniaturizationPass::NeedCompress(const tensor::TensorPtr &tensor) { - auto tensor_data_ptr = tensor->data_ptr(); - auto item_size = tensor_data_ptr->itemsize(); - auto item_num = tensor_data_ptr->size(); - auto data_ptr = tensor_data_ptr->data(); + auto item_size = tensor->DataItemSize(); + auto item_num = tensor->DataSize(); + auto data_ptr = tensor->device_address()->GetMutablePtr(); // No need cast to fill ops while tensor data size is small. 
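// ---------------------------------------------------------------------------
// Editor's sketch (annotation kept entirely in comments, since this point sits
// inside the body of NeedCompress()): the accessor change above moves the raw
// payload pointer from tensor->data_ptr()->data() to
// tensor->device_address()->GetMutablePtr(), with the element size and count
// coming from DataItemSize()/DataSize(). Because the pointer now hangs off the
// DeviceAddress, callers must null-check device_address() before dereferencing,
// which is what GetTensorFromNode() above now does. An equivalent standalone
// shape of the scan, with illustrative names:
//
//   bool AllAdjacentElementsEqual(const tensor::TensorPtr &tensor) {
//     if (tensor == nullptr || tensor->device_address() == nullptr) {
//       return false;
//     }
//     auto item_size = tensor->DataItemSize();   // was data_ptr()->itemsize()
//     auto item_num = tensor->DataSize();        // was data_ptr()->size()
//     auto data = static_cast<const uint8_t *>(tensor->device_address()->GetMutablePtr());
//     if (data == nullptr) {
//       return false;
//     }
//     for (size_t idx = 1; idx < item_num; ++idx) {  // size_t, matching the loop fix below
//       if (memcmp(data + idx * item_size, data + (idx - 1) * item_size, item_size) != 0) {
//         return false;
//       }
//     }
//     return true;
//   }
// ---------------------------------------------------------------------------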
if (item_num < COMPRESS_TRIGGER_SIZE_) { return false; } int ret = 0; - for (ssize_t idx = 1; idx < item_num; idx++) { + for (size_t idx = 1; idx < item_num; idx++) { auto offset = idx * item_size; // No memcmp_s provide in secure lib of huawei ret = memcmp(static_cast(data_ptr) + offset, static_cast(data_ptr) + offset - item_size, @@ -67,15 +66,14 @@ bool MiniaturizationPass::NeedCompress(const tensor::TensorPtr &tensor) { } static inline ValuePtr GetFirstVal(const tensor::TensorPtr &tensor) { - auto tensor_data_ptr = tensor->data_ptr(); + auto tensor_data_ptr = tensor->device_address()->GetMutablePtr(); auto data_type = tensor->data_type(); - auto data_ptr = tensor_data_ptr->data(); if (data_type == kNumberTypeFloat32) { - float val = static_cast(data_ptr)[0]; + float val = static_cast(tensor_data_ptr)[0]; return MakeValue(val); } if (data_type == kNumberTypeUInt32) { - int32_t val = static_cast(data_ptr)[0]; + int32_t val = static_cast(tensor_data_ptr)[0]; return MakeValue(val); } return nullptr; diff --git a/mindspore-lite/tools/optimizer/graph/node_infershape.cc b/mindspore-lite/tools/optimizer/graph/node_infershape.cc index 69d3edf9..0ddc5d92 100644 --- a/mindspore-lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore-lite/tools/optimizer/graph/node_infershape.cc @@ -58,6 +58,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_z.h" +#include "ir/tensor_api.h" namespace mindspore { namespace opt { static const std::unordered_set kNNACLToOpsInfer = { @@ -168,7 +169,7 @@ void RectifyFormat(const std::vector &inputs, FmkType fmk_type) tensor::TensorPtr NewTensorInfo(const lite::Tensor *tensor) { std::vector shape(tensor->shape()); std::vector shape_vector(shape.begin(), shape.end()); - auto tensor_info = std::make_shared(tensor->data_type(), shape_vector); + auto tensor_info = tensor::empty(tensor->data_type(), shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed"; return nullptr; diff --git a/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc b/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc index 0fb225ab..5ce17203 100644 --- a/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc @@ -29,6 +29,7 @@ #include "mindspore/ops/op_def/sequence_ops.h" #include "tools/common/func_graph_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" +#include "ir/tensor_api.h" namespace mindspore::opt { namespace { @@ -78,7 +79,7 @@ bool OutputVariablePass::Run(const FuncGraphPtr &graph) { } abstract::ShapePtr shape = dyn_cast(make_tuple_input->Shape()); MS_CHECK_TRUE_MSG(shape != nullptr, false, "shape is nullptr!"); - tensor::TensorPtr tensor_data = std::make_shared(type_ptr->type_id(), shape->shape()); + tensor::TensorPtr tensor_data = tensor::empty(type_ptr->type_id(), shape->shape(), device::DeviceType::kCPU); float *data_addr = static_cast(tensor_data->data_c()); for (size_t j = 0; i < tensor_data->DataSize(); ++j) { diff --git a/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc b/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc index acf9c930..82aa86fc 100644 --- a/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc @@ -35,6 +35,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include 
"mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" +#include "ir/tensor_api.h" /* This pass changes the following pattern(s). @@ -119,7 +120,7 @@ ValueNodePtr ScalarOpPass::GenerateScalarValueTensor(const FuncGraphPtr &func_gr } int32_t scalar_value = *reinterpret_cast(data_info.data_.data()); ShapeVector const_data_shape = {1}; - tensor::TensorPtr const_data_tensor = std::make_shared(kNumberTypeInt32, const_data_shape); + tensor::TensorPtr const_data_tensor = tensor::empty(kNumberTypeInt32, const_data_shape, device::DeviceType::kCPU); auto *val = static_cast(const_data_tensor->data_c()); *val = scalar_value; auto const_value_node = NewValueNode(const_data_tensor); diff --git a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc index d997fcd7..6f988070 100644 --- a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc @@ -38,6 +38,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::schema::PrimitiveType_Conv2DFusion; +#include "ir/tensor_api.h" namespace mindspore { namespace opt { namespace { @@ -83,7 +84,7 @@ void CreateSplitConstantTensors(const tensor::TensorPtr &constant_tensor, const } auto cur_shape = UP_DIV(split_dim_size * visited_block, total_block_count); split_constant_shapes.at(i).at(split_dim) = cur_shape; - auto tensor = std::make_shared(weight_type_id, split_constant_shapes.at(i)); + auto tensor = tensor::empty(weight_type_id, split_constant_shapes.at(i), device::DeviceType::kCPU); if (tensor == nullptr) { MS_LOG(ERROR) << "make shared failed."; split_constant_tensors->clear(); -- Gitee From 6c95e45860d42df38a6db650a5b9a997c1ff1a39 Mon Sep 17 00:00:00 2001 From: liuf9 Date: Thu, 31 Jul 2025 15:33:53 +0800 Subject: [PATCH 2/7] merge tensor-storage-refactor 2 --- mindspore-lite/cmake/ccsrc_converter.cmake | 1 - mindspore-lite/src/CMakeLists.txt | 4 ++-- .../delegate/ascend_ge/ge_device_context.cc | 1 - .../delegate/tensorrt/tensorrt_graph_executor.cc | 3 ++- .../mindir_loader/mindir_model/mindir_model_util.cc | 4 +++- .../src/extendrt/utils/func_graph_utils.cc | 4 ++-- .../test/common/import_from_meta_graphT.cc | 4 ++-- mindspore-lite/tools/common/custom_ascend_utils.cc | 8 ++++---- mindspore-lite/tools/common/tensor_util.cc | 6 +++--- .../cxx_api/model/acl/model_converter.cc | 2 -- mindspore-lite/tools/converter/export_model.cc | 4 ++-- .../tools/converter/import/mindir_adjust.cc | 4 ++-- .../converter/parser/onnx/onnx_constant_parser.cc | 4 ++-- .../converter/parser/onnx/onnx_model_parser.cc | 4 ++-- .../tools/converter/parser/onnx/onnx_node_parser.cc | 4 ++-- .../converter/parser/pytorch/pytorch_lstm_adjust.cc | 3 ++- .../tools/converter/parser/tf/tf_model_parser.cc | 4 ++-- .../tools/converter/quantizer/quantize_util.cc | 13 +++++++------ .../tools/converter/quantizer/split_shared_bias.cc | 4 ++-- .../tools/converter/quantizer/weight_quantizer.cc | 5 +++-- .../tools/optimizer/common/format_utils.cc | 4 ++-- mindspore-lite/tools/optimizer/common/gllo_utils.cc | 4 ++-- .../tools/optimizer/const_fold/fold_utils.cc | 4 ++-- .../tools/optimizer/fusion/batchmatmul_fusion.cc | 4 ++-- .../fusion/kv_cache_mgr_one_branch_fusion.cc | 4 +++- .../optimizer/fusion/multi_head_attention_fusion.cc | 8 ++++---- .../optimizer/fusion/tf_bidirection_gru_fusion.cc | 4 ++-- .../optimizer/graph/decrease_transpose_algo.cc | 
5 +++-- .../tools/optimizer/graph/grouped_matmul_op_pass.cc | 4 ++-- .../graph/input_and_output_variable_pass.cc | 4 ++-- .../tools/optimizer/graph/node_infershape.cc | 6 +++--- .../tools/optimizer/graph/output_variable_pass.cc | 4 ++-- .../tools/optimizer/graph/scalar_op_pass.cc | 6 +++--- .../optimizer/graph/send_op_add_control_depend.cc | 3 ++- .../optimizer/parallel/depthwise_conv2d_info.cc | 4 ++-- 35 files changed, 80 insertions(+), 74 deletions(-) diff --git a/mindspore-lite/cmake/ccsrc_converter.cmake b/mindspore-lite/cmake/ccsrc_converter.cmake index 22480151..abddb41c 100644 --- a/mindspore-lite/cmake/ccsrc_converter.cmake +++ b/mindspore-lite/cmake/ccsrc_converter.cmake @@ -14,7 +14,6 @@ if(MSLITE_ENABLE_CONVERTER) set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../tools) set(CCSRC_SRC - ${CCSRC_DIR}/backend/backend_manager/backend_jit_config.cc ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc ${CCSRC_DIR}/backend/common/optimizer/visitor.cc ${CCSRC_DIR}/backend/common/optimizer/graph_optimizer.cc diff --git a/mindspore-lite/src/CMakeLists.txt b/mindspore-lite/src/CMakeLists.txt index 84c54640..52019938 100644 --- a/mindspore-lite/src/CMakeLists.txt +++ b/mindspore-lite/src/CMakeLists.txt @@ -576,8 +576,8 @@ if(MSLITE_ENABLE_MULTI_LAYOUT) endif() if(MSLITE_ENABLE_RUNTIME_GLOG) - target_link_libraries(mindspore-lite mindspore::glog) - target_link_libraries(mindspore-lite_static mindspore::glog) + target_link_libraries(mindspore-lite mindspore::glog mindspore::securec) + target_link_libraries(mindspore-lite_static mindspore::glog mindspore::securec) endif() if(DEFINED ARCHS) diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc index 9c4d6b34..8fc440b2 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc @@ -29,7 +29,6 @@ #include "common/common.h" #include "extendrt/delegate/comm_group_info.h" #include "extendrt/delegate/ascend_ge/ge_utils.h" -#include "backend/common/session/executor.h" #include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc index c1b507b4..18067216 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc @@ -37,6 +37,7 @@ #include "ir/device_address_maker.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" +#include "ir/tensor_new.h" namespace mindspore::lite { namespace { @@ -96,7 +97,7 @@ tensor::TensorPtr GetConstNodeValue(AnfNodePtr input_node) { if (type_ptr == nullptr) { return nullptr; } - return std::make_shared(static_cast(type_ptr->type_id()), type_ptr->type()); + return tensor::from_scalar(static_cast(type_ptr->type_id()), type_ptr->type()); } MS_LOG(WARNING) << "Unexpected value type " << value->type_name() << " for " << input_node->fullname_with_scope(); return nullptr; diff --git a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc index c5037435..a1b40bc2 100644 --- 
a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc +++ b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc @@ -21,6 +21,7 @@ #include "extendrt/mindir_loader/mindir_model/mindir_model_util.h" #include "ir/tensor.h" #include "ir/value.h" +#include "ir/tensor_new.h" #include "include/errorcode.h" #include "nnacl/op_base.h" #include "src/common/common.h" @@ -86,7 +87,8 @@ mindspore::ValuePtr MindirModelUtil::MakeValueFromTensorAttribute(const mind_ir: for (int i = 0; i < tensor_proto.dims_size(); i++) { shape.push_back(tensor_proto.dims(i)); } - tensor::TensorPtr tensor = tensor::empty(kDefaultValueSwitchMap[attr_tensor_type], shape, device::DeviceType::kCPU); + tensor::TensorPtr tensor = + tensor::from_spec(kDefaultValueSwitchMap[attr_tensor_type], shape, device::DeviceType::kCPU); MS_EXCEPTION_IF_NULL(tensor); const std::string &tensor_buf = tensor_proto.raw_data(); diff --git a/mindspore-lite/src/extendrt/utils/func_graph_utils.cc b/mindspore-lite/src/extendrt/utils/func_graph_utils.cc index f54fe973..84db140f 100644 --- a/mindspore-lite/src/extendrt/utils/func_graph_utils.cc +++ b/mindspore-lite/src/extendrt/utils/func_graph_utils.cc @@ -35,7 +35,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { const PrimitivePtr kPrimMakeTupleV2 = std::make_shared("make_tuple"); ValuePtr FuncGraphUtils::GetNodeValuePtr(AnfNodePtr input_node) { @@ -84,7 +84,7 @@ tensor::TensorPtr FuncGraphUtils::GetConstNodeValue(AnfNodePtr input_node) { if (type_ptr == nullptr) { return nullptr; } - return std::make_shared(static_cast(type_ptr->type_id()), type_ptr->type()); + return tensor::from_scalar(static_cast(type_ptr->type_id()), type_ptr->type()); } MS_LOG(WARNING) << "Unexpected value type " << value->type_name() << " for " << input_node->fullname_with_scope(); return nullptr; diff --git a/mindspore-lite/test/common/import_from_meta_graphT.cc b/mindspore-lite/test/common/import_from_meta_graphT.cc index b3d19e6a..8e0c6a16 100644 --- a/mindspore-lite/test/common/import_from_meta_graphT.cc +++ b/mindspore-lite/test/common/import_from_meta_graphT.cc @@ -24,7 +24,7 @@ #include "include/errorcode.h" #include "src/common/utils.h" #include "tools/common/tensor_util.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::lite { AnfNodePtr AnfImporterFromMetaGraphT::GetNode(int tensor_id) { @@ -57,7 +57,7 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() { } else { parameter->set_name("const-" + std::to_string(i)); } - tensor::TensorPtr tensor_info = tensor::empty(type_id, shape_vector, device::DeviceType::kCPU); + tensor::TensorPtr tensor_info = tensor::from_spec(type_id, shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "create tensor info failed."; return RET_ERROR; diff --git a/mindspore-lite/tools/common/custom_ascend_utils.cc b/mindspore-lite/tools/common/custom_ascend_utils.cc index d67e1ecf..8124115b 100644 --- a/mindspore-lite/tools/common/custom_ascend_utils.cc +++ b/mindspore-lite/tools/common/custom_ascend_utils.cc @@ -19,7 +19,7 @@ #include "tools/common/func_graph_utils.h" #include "mindspore/ops/infer/tuple_get_item.h" #include "src/common/common.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" #include "tools/optimizer/common/gllo_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" @@ 
-90,8 +90,8 @@ ParameterPtr CustomAscendUtils::CreateOmParameter(const FuncGraphPtr &func_graph MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "abstract_tensor is nullptr."); om_parameter->set_abstract(abstract_tensor); - auto param_value = - tensor::empty(kNumberTypeUInt8, ShapeVector({static_cast(om_data.DataSize())}), device::DeviceType::kCPU); + auto param_value = tensor::from_spec(kNumberTypeUInt8, ShapeVector({static_cast(om_data.DataSize())}), + device::DeviceType::kCPU); MS_CHECK_TRUE_MSG(param_value != nullptr, nullptr, "param_value is nullptr."); auto tensor_data = param_value->data_c(); MS_CHECK_TRUE_MSG(tensor_data != nullptr, nullptr, "New Tensor failed."); @@ -174,7 +174,7 @@ bool CustomAscendUtils::GetZeroValueRefDatas(const ops::PrimitiveCPtr &primc, auto param_name = GetValue(value_ptr_list[i]); auto data_type = static_cast(GetValue(value_ptr_list[i + 1])); auto param_shape = GetValue(value_ptr_list[i + 2]); - auto tensor = tensor::empty(data_type, param_shape, device::DeviceType::kCPU); + auto tensor = tensor::from_spec(data_type, param_shape, device::DeviceType::kCPU); ref_infos->push_back(std::make_pair(param_name, tensor)); } return true; diff --git a/mindspore-lite/tools/common/tensor_util.cc b/mindspore-lite/tools/common/tensor_util.cc index 41e21f23..1e0131ad 100644 --- a/mindspore-lite/tools/common/tensor_util.cc +++ b/mindspore-lite/tools/common/tensor_util.cc @@ -20,7 +20,7 @@ #include "tools/common/graph_util.h" #include "abstract/utils.h" #include "nnacl/op_base.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::lite { namespace { @@ -77,14 +77,14 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std tensor::TensorPtr tensor_info = nullptr; if (shape.empty() && data_size == mindspore::abstract::TypeIdSize(data_type)) { ShapeVector scalar_shape = {1}; - tensor_info = tensor::empty(data_type, scalar_shape, device::DeviceType::kCPU); + tensor_info = tensor::from_spec(data_type, scalar_shape, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor init failed"; return nullptr; } tensor_info->set_shape({}); } else { - tensor_info = tensor::empty(data_type, shape, device::DeviceType::kCPU); + tensor_info = tensor::from_spec(data_type, shape, device::DeviceType::kCPU); } if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor init failed"; diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc index 37d583c9..1510d375 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc @@ -15,7 +15,6 @@ */ #include "cxx_api/model/acl/model_converter.h" - #include #include #include @@ -24,7 +23,6 @@ #include "graph/graph_buffer.h" #include "graph/graph.h" #include "cxx_api/model/aoe/auto_tune_process.h" -#include "plugin/device/ascend/optimizer/ge_optimization.h" #include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/acl_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" diff --git a/mindspore-lite/tools/converter/export_model.cc b/mindspore-lite/tools/converter/export_model.cc index e8163158..d422faa3 100644 --- a/mindspore-lite/tools/converter/export_model.cc +++ 
b/mindspore-lite/tools/converter/export_model.cc @@ -39,7 +39,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace lite { namespace { @@ -113,7 +113,7 @@ AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const } std::shared_ptr tensor_info; if (static_cast(data_info.compress_type_) == TensorCompressionType::kNoCompression) { - tensor_info = tensor::empty(static_cast(data_info.data_type_), shape_vec, device::DeviceType::kCPU); + tensor_info = tensor::from_spec(static_cast(data_info.data_type_), shape_vec, device::DeviceType::kCPU); } else { tensor_info = std::make_shared(static_cast(data_info.data_type_), shape_vec, data_info.data_.size(), diff --git a/mindspore-lite/tools/converter/import/mindir_adjust.cc b/mindspore-lite/tools/converter/import/mindir_adjust.cc index 29075a35..3435900e 100644 --- a/mindspore-lite/tools/converter/import/mindir_adjust.cc +++ b/mindspore-lite/tools/converter/import/mindir_adjust.cc @@ -31,7 +31,7 @@ #include "infer/fake_quant_param.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace lite { namespace { @@ -205,7 +205,7 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) { MS_CHECK_TRUE_MSG(utils::cast(abstract_tensor->BuildShape()) != nullptr, RET_NULL_PTR, "Failed to cast pointer."); auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); - auto dest_tensor_info = tensor::empty(kNumberTypeInt32, shape_vector, device::DeviceType::kCPU); + auto dest_tensor_info = tensor::from_spec(kNumberTypeInt32, shape_vector, device::DeviceType::kCPU); MS_CHECK_TRUE_MSG(dest_tensor_info != nullptr, RET_NULL_PTR, "dest_tensor_info is nullptr."); MS_CHECK_TRUE_MSG(dest_tensor_info->data_c() != nullptr, RET_ERROR, "dest_tensor_info->data_c() is nullptr"); MS_CHECK_TRUE_MSG(dest_tensor_info->DataNBytes() >= static_cast(sizeof(int32_t)), RET_ERROR, diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc index 3f9283da..05d0eee9 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -25,8 +25,8 @@ #include "tools/converter/ops/ops_def.h" #include "tools/common/tensor_util.h" #include "nnacl/op_base.h" +#include "ir/tensor_new.h" -#include "ir/tensor_api.h" namespace mindspore { namespace lite { namespace { @@ -56,7 +56,7 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t return RET_ERROR; } std::vector shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); - tensor_info = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU); + tensor_info = tensor::from_spec(data_type, shape_vector, device::DeviceType::kCPU); MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "create tensor_info return nullptr"); std::vector shape; std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc index 60e0c071..f9eeff78 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -53,7 +53,7 @@ #include "tools/converter/parser/einsum_adjust.h" using mindspore::converter::kFmkTypeOnnx; -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace lite { namespace { @@ -273,7 +273,7 @@ STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::Tensor return RET_ERROR; } } else { - tensor_info = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU); + tensor_info = tensor::from_spec(data_type, shape_vector, device::DeviceType::kCPU); MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_NULL_PTR, "create tensor_info return nullptr"); std::vector shape; std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc index cc9c4bab..549faaff 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -24,7 +24,7 @@ #include "src/common/file_utils.h" #include "utils/ms_utils_secure.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace lite { namespace { @@ -112,7 +112,7 @@ tensor::TensorPtr OnnxNodeParser::CopyOnnxTensorData(const onnx::TensorProto &on return nullptr; } std::vector shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); - auto tensor_info = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU); + auto tensor_info = tensor::from_spec(data_type, shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new a tensor::Tensor failed, data type: " << data_type << ", shape: " << shape_vector; return nullptr; diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc index 16e1e7bd..b0a74206 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc @@ -26,6 +26,7 @@ #include "mindspore/ops/op_def/sequence_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { @@ -190,7 +191,7 @@ bool PytorchLstmAdjustPass::AdjustDataFormat(const ParameterPtr ¶meter) { return false; } - auto new_tensor = std::make_shared(weight->data_type(), weight->shape(), new_data, data_size); + auto new_tensor = tensor::from_buffer(weight->data_type(), weight->shape(), new_data, data_size); MS_CHECK_TRUE_RET(new_tensor != nullptr, false); parameter->set_default_param(new_tensor); diff --git a/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc index fb4187cb..4aca8046 100644 --- a/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/tf/tf_model_parser.cc @@ -55,7 +55,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_w.h" using mindspore::converter::kFmkTypeTf; -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace lite { namespace { @@ -495,7 +495,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co for (int i = 0; i < tensor_shape.dim_size(); i++) { 
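The hunks above repeat the same mechanical substitutions: ir/tensor_api.h gives way to ir/tensor_new.h, tensor::empty() to tensor::from_spec(), and direct std::make_shared<tensor::Tensor>(...) construction to tensor::from_buffer() (with tensor::from_scalar() appearing later in the patch). A minimal sketch of the three call shapes side by side, with signatures inferred from these call sites only; the real declarations live in ir/tensor_new.h, which is not shown here, and MakeTensorExamples is a hypothetical illustration, not code from this change:

#include <vector>

#include "ir/tensor_new.h"  // replaces ir/tensor_api.h throughout this patch

namespace mindspore {
tensor::TensorPtr MakeTensorExamples() {
  // tensor::empty(type, shape, device) -> tensor::from_spec(type, shape, device):
  // allocate a CPU tensor from a dtype/shape spec.
  auto by_spec = tensor::from_spec(kNumberTypeFloat32, ShapeVector{2, 3}, device::DeviceType::kCPU);

  // std::make_shared<tensor::Tensor>(type, shape, data, nbytes) ->
  // tensor::from_buffer(type, shape, data, nbytes): build a tensor from a host
  // buffer of nbytes bytes.
  std::vector<float> host(6, 1.0f);
  auto by_buffer = tensor::from_buffer(kNumberTypeFloat32, ShapeVector{2, 3}, host.data(),
                                       host.size() * sizeof(float));

  // std::make_shared<tensor::Tensor>(0.0) -> tensor::from_scalar(0.0):
  // build a scalar tensor from a single value.
  auto by_scalar = tensor::from_scalar(0.0);

  (void)by_spec;
  (void)by_scalar;
  return by_buffer;
}
}  // namespace mindspore

Both the removed tensor::empty calls and the new from_spec calls pass device::DeviceType::kCPU, so at these sites the rename appears to be the only interface change.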
shape_vector->push_back(tensor_shape.dim(i).size()); } - auto tensor_info = tensor::empty(type, *shape_vector, device::DeviceType::kCPU); + auto tensor_info = tensor::from_spec(type, *shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "tensor info is nullptr"; return RET_ERROR; diff --git a/mindspore-lite/tools/converter/quantizer/quantize_util.cc b/mindspore-lite/tools/converter/quantizer/quantize_util.cc index d61f3394..b7fd3fad 100644 --- a/mindspore-lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore-lite/tools/converter/quantizer/quantize_util.cc @@ -42,6 +42,7 @@ #include "src/common/file_utils.h" #include "src/litert/cxx_api/tensor/tensor_impl.h" #include "ir/anf.h" +#include "ir/tensor_new.h" #include "tools/converter/export_model.h" #include "tools/converter/parser/parser_utils.h" #include "mindspore/ops/op_def/other_ops.h" @@ -771,8 +772,8 @@ int ConvertCNodeFp32ToFp16(const CNodePtr &cnode) { for (size_t j = 0; j < tensor_info->DataSize(); j++) { fp16_data[j] = mindspore::Float16(data[j]); } - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared( - kNumberTypeFloat16, tensor_info->shape_c(), fp16_data.data(), fp16_data.size() * sizeof(float) / 2); + auto tensor_ptr = mindspore::tensor::from_buffer(kNumberTypeFloat16, tensor_info->shape_c(), fp16_data.data(), + fp16_data.size() * sizeof(float) / 2); param_node->set_default_param(tensor_ptr); param_node->set_abstract(tensor_ptr->ToAbstract()); } @@ -801,8 +802,8 @@ int ConvertCNodeFp32ToFp16(const CNodePtr &cnode) { for (int j = 0; j < total_size; j++) { fp16_data[j] = mindspore::Float16(data[j]); } - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared( - kNumberTypeFloat16, shapes, fp16_data.data(), fp16_data.size() * sizeof(float) / 2); + auto tensor_ptr = mindspore::tensor::from_buffer(kNumberTypeFloat16, shapes, fp16_data.data(), + fp16_data.size() * sizeof(float) / 2); auto values = MakeValue(tensor_ptr); value_node->set_value(values); value_node->set_abstract(tensor_ptr->ToAbstract()); @@ -842,8 +843,8 @@ int ConvertCNodeFp16ToFp32(const CNodePtr &cnode) { for (size_t j = 0; j < tensor_info->DataSize(); j++) { fp32_data[j] = mindspore::Float16::ToFloat32(data[j]); } - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared( - kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float)); + mindspore::tensor::TensorPtr tensor_ptr = tensor::from_buffer(kNumberTypeFloat32, tensor_info->shape_c(), + fp32_data.data(), fp32_data.size() * sizeof(float)); tensor::TensorPtr input_tensor = quant::GetNodeTensor(input); MS_CHECK_TRUE_MSG(input_tensor != nullptr, RET_NULL_PTR, "Get node tensor failed."); diff --git a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc index 7addf1e4..34f192a3 100644 --- a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc +++ b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc @@ -23,7 +23,7 @@ #include "tools/converter/quantizer/quantize_util.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::lite::quant { AnfNodePtr SplitSharedBias::CloneParameterNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &func_graph, @@ -51,7 +51,7 @@ AnfNodePtr SplitSharedBias::CloneParameterNode(const CNodePtr &cnode, size_t ind } std::shared_ptr tensor_info; if 
(static_cast(data_info.compress_type_) == TensorCompressionType::kNoCompression) { - tensor_info = tensor::empty(static_cast(data_info.data_type_), shape_vec, device::DeviceType::kCPU); + tensor_info = tensor::from_spec(static_cast(data_info.data_type_), shape_vec, device::DeviceType::kCPU); } else { tensor_info = std::make_shared(static_cast(data_info.data_type_), shape_vec, data_info.data_.size(), diff --git a/mindspore-lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore-lite/tools/converter/quantizer/weight_quantizer.cc index 9d329feb..a1fd671a 100644 --- a/mindspore-lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore-lite/tools/converter/quantizer/weight_quantizer.cc @@ -48,6 +48,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" +#include "ir/tensor_new.h" namespace mindspore::lite::quant { namespace { @@ -68,8 +69,8 @@ tensor::TensorPtr ConvertParameterFp16TensorToFp32(const ParameterPtr ¶meter for (size_t j = 0; j < tensor_info->DataSize(); j++) { fp32_data[j] = mindspore::Float16::ToFloat32(data[j]); } - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared( - kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float)); + auto tensor_ptr = mindspore::tensor::from_buffer(kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), + fp32_data.size() * sizeof(float)); parameter->set_default_param(tensor_ptr); parameter->set_abstract(tensor_ptr->ToAbstract()); return tensor_ptr; diff --git a/mindspore-lite/tools/optimizer/common/format_utils.cc b/mindspore-lite/tools/optimizer/common/format_utils.cc index c2e93a2c..6806a3ce 100644 --- a/mindspore-lite/tools/optimizer/common/format_utils.cc +++ b/mindspore-lite/tools/optimizer/common/format_utils.cc @@ -82,7 +82,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { // treat the weight of deformableConv2d as an input instead of a const because of the ops infershape only support nchw. @@ -342,7 +342,7 @@ int SetAbstractTensorInfo(const AbstractBasePtr &abstract) { TypeId type = lite::GetAbstractTensorDtype(abstract->cast()); // For kObjectTypeTensorType, the abstract value is TensorList amd does not need to reset. 
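The quantizer hunks above (quantize_util.cc and weight_quantizer.cc) all follow one pattern: convert the weight buffer element by element, then hand the converted buffer to tensor::from_buffer and re-attach the result to the node. A condensed sketch of the fp32-to-fp16 direction, reusing the enclosing file's includes and assuming the call shapes shown above; ConvertFp32ParamToFp16Sketch is illustrative, not part of the change. The patch's byte count of fp16_data.size() * sizeof(float) / 2 is simply two bytes per fp16 element:

#include <vector>

#include "ir/tensor_new.h"

namespace mindspore::lite::quant {
// Sketch: convert a parameter's fp32 default tensor to fp16 and re-attach it.
int ConvertFp32ParamToFp16Sketch(const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info) {
  auto *data = static_cast<float *>(tensor_info->data_c());
  std::vector<Float16> fp16_data(tensor_info->DataSize());
  for (size_t j = 0; j < tensor_info->DataSize(); ++j) {
    fp16_data[j] = Float16(data[j]);  // per-element narrowing conversion
  }
  // Two bytes per fp16 element, equivalent to the patch's sizeof(float) / 2.
  auto tensor_ptr = tensor::from_buffer(kNumberTypeFloat16, tensor_info->shape_c(), fp16_data.data(),
                                        fp16_data.size() * sizeof(Float16));
  if (tensor_ptr == nullptr) {
    return RET_NULL_PTR;
  }
  param_node->set_default_param(tensor_ptr);
  param_node->set_abstract(tensor_ptr->ToAbstract());
  return RET_OK;
}
}  // namespace mindspore::lite::quant

The fp16-to-fp32 direction in ConvertCNodeFp16ToFp32 and ConvertParameterFp16TensorToFp32 mirrors this shape with Float16::ToFloat32 and a sizeof(float) byte count.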
if (type != kObjectTypeTensorType) { - auto tensor_info = tensor::empty(type, shape, device::DeviceType::kCPU); + auto tensor_info = tensor::from_spec(type, shape, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed"; return RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/common/gllo_utils.cc b/mindspore-lite/tools/optimizer/common/gllo_utils.cc index 0860fc05..ce5bbb53 100644 --- a/mindspore-lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore-lite/tools/optimizer/common/gllo_utils.cc @@ -57,7 +57,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { namespace { @@ -882,7 +882,7 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const tensor::Te } param_node->set_name(node_name); param_node->debug_info()->set_name(node_name); - auto tensor_info_new = tensor::empty(data_type, shape_vector, device::DeviceType::kCPU); + auto tensor_info_new = tensor::from_spec(data_type, shape_vector, device::DeviceType::kCPU); if (tensor_info_new == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed."; return nullptr; diff --git a/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc b/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc index 3b0e0ee6..6a4648e7 100644 --- a/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc +++ b/mindspore-lite/tools/optimizer/const_fold/fold_utils.cc @@ -43,7 +43,7 @@ using mindspore::lite::KernelRegistry; using mindspore::lite::Tensor; -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { namespace { @@ -58,7 +58,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), [](const int32_t &value) { return static_cast(value); }); - auto tensor_info = tensor::empty(tensor->data_type(), shape_vector, device::DeviceType::kCPU); + auto tensor_info = tensor::from_spec(tensor->data_type(), shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "create tensor info failed."; return nullptr; diff --git a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc index 3986b060..f231088e 100644 --- a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -34,7 +34,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::opt { namespace { @@ -241,7 +241,7 @@ int ResetReshapeParameters(const AnfNodePtr &reshape_node) { shape[0] = rmatmul_input_shape[0] + 1; } - auto tensor_info = tensor::empty(shape_tensor->data_type(), shape, device::DeviceType::kCPU); + auto tensor_info = tensor::from_spec(shape_tensor->data_type(), shape, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "Create tensor info failed"; return RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc index af5eadf1..07f6aaa3 100644 --- 
a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc @@ -37,6 +37,8 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" +#include "ir/tensor_new.h" + namespace mindspore::opt { const BaseRef KVCacheMgrOneBranchFusion::DefinePattern() const { if (!InitVar()) { @@ -73,7 +75,7 @@ const BaseRef KVCacheMgrOneBranchFusion::DefinePattern() const { tensor::TensorPtr KVCacheMgrOneBranchFusion::ConstData(int32_t padding_length) const { std::vector shp = {padding_length}; - tensor::TensorPtr const_data = tensor::empty(kInt32->type_id(), shp, device::DeviceType::kCPU); + tensor::TensorPtr const_data = tensor::from_spec(kInt32->type_id(), shp, device::DeviceType::kCPU); MS_CHECK_TRUE_RET(const_data != nullptr && const_data->data_c() != nullptr, nullptr); auto *val = static_cast(const_data->data_c()); for (int i = 0; i < padding_length; ++i) { diff --git a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc index 22f348f2..06fac39c 100644 --- a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc @@ -42,7 +42,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::opt { namespace { @@ -657,12 +657,12 @@ std::shared_ptr ConcatTensors(const std::vector(concat_tensor->data_c()) + offset; - auto transpose_tensor = tensor::empty(base_data_type, tensor->shape(), device::DeviceType::kCPU); + auto transpose_tensor = tensor::from_spec(base_data_type, tensor->shape(), device::DeviceType::kCPU); if (transpose && !transpose_b) { switch (base_data_type) { case kNumberTypeFloat32: { @@ -693,7 +693,7 @@ std::shared_ptr ConcatTensors(const std::vector tshape = {new_shape[1], new_shape[0]}; - auto transposed_tensor = tensor::empty(base_data_type, tshape, device::DeviceType::kCPU); + auto transposed_tensor = tensor::from_spec(base_data_type, tshape, device::DeviceType::kCPU); switch (base_data_type) { case kNumberTypeFloat32: { auto status = TransposeMatrix(concat_tensor, transposed_tensor); diff --git a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc index 54859d04..8e845de1 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc @@ -42,7 +42,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_w.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { namespace { @@ -457,7 +457,7 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun } parameter->set_abstract(abstract); - auto gate_weight_default = tensor::empty(type, shape_vector, device::DeviceType::kCPU); + auto gate_weight_default = tensor::from_spec(type, shape_vector, device::DeviceType::kCPU); if (gate_weight_default == nullptr) { MS_LOG(ERROR) << "gate_weight_default is nullptr"; 
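A behavioral note these fusion passes rely on: from_spec, like tensor::empty before it, is assumed here to allocate without zero-initializing, which is why ConstData above writes every element through data_c() and why CreateAssign (in the hunks below) zero-fills its float buffer. A sketch of that allocate-then-fill idiom under the same assumption; MakeZeroFilledSketch is illustrative only:

#include "ir/tensor_new.h"

namespace mindspore::opt {
// Sketch: from_spec() is assumed to return uninitialized CPU storage (as
// tensor::empty did), so callers that need zeros must write them explicitly.
tensor::TensorPtr MakeZeroFilledSketch(TypeId type_id, const ShapeVector &shape) {
  auto tensor = tensor::from_spec(type_id, shape, device::DeviceType::kCPU);
  if (tensor == nullptr || tensor->data_c() == nullptr) {
    return nullptr;
  }
  auto *val = static_cast<float *>(tensor->data_c());  // assumes a float dtype
  for (size_t i = 0; i < tensor->DataSize(); ++i) {
    val[i] = 0.0f;
  }
  return tensor;
}
}  // namespace mindspore::opt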
return nullptr; diff --git a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc index 5d984bf7..548f84c2 100644 --- a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc +++ b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc @@ -36,6 +36,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_w.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { @@ -241,8 +242,8 @@ int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, ShapeVector tmp_shape(DIMENSION_4D - expand_shape.size(), 1); (void)expand_shape.insert(expand_shape.begin(), tmp_shape.begin(), tmp_shape.end()); } - auto tensor = std::make_shared(static_cast(data_info.data_type_), expand_shape, - data_info.data_.data(), data_info.data_.size()); + auto tensor = tensor::from_buffer(static_cast(data_info.data_type_), expand_shape, data_info.data_.data(), + data_info.data_.size()); MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr"); if (trans_type == kNHWC2NCHW) { (void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW); diff --git a/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc b/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc index 066cb09d..c9d079a1 100644 --- a/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/grouped_matmul_op_pass.cc @@ -34,7 +34,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::opt { #if !defined(_WIN32) && !defined(_WIN64) @@ -105,7 +105,7 @@ void GroupedMatmulOpPass::UseEmptyNodeReplaceNone(const FuncGraphPtr &graph, con // create empty tensor auto tensor_type = OpInputDtypeMap.at(cnode_name).at(input_idx); std::vector tensor_shape = {0}; - auto empty_tensor = tensor::empty(tensor_type, tensor_shape, device::DeviceType::kCPU); + auto empty_tensor = tensor::from_spec(tensor_type, tensor_shape, device::DeviceType::kCPU); // create node auto empty_node = std::make_shared(empty_tensor); ValueNodePtr empty_value_node = empty_node->cast(); diff --git a/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc b/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc index 63f0e7f4..5f79c591 100644 --- a/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/input_and_output_variable_pass.cc @@ -29,7 +29,7 @@ #include "mindspore/ops/op_def/sequence_ops.h" #include "tools/common/func_graph_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::opt { @@ -155,7 +155,7 @@ CNodePtr InputAndOutputVariablePass::CreateAssign(const AnfNodePtr &anf_node, co MS_LOG(ERROR) << "type ptr is nullptr"; return nullptr; } - tensor::TensorPtr tensor_data = tensor::empty(type_ptr->type_id(), shape, device::DeviceType::kCPU); + tensor::TensorPtr tensor_data = tensor::from_spec(type_ptr->type_id(), shape, device::DeviceType::kCPU); float *val = static_cast(tensor_data->data_c()); for (size_t i = 0; i < 
tensor_data->DataSize(); ++i) { *(val + i) = 0; diff --git a/mindspore-lite/tools/optimizer/graph/node_infershape.cc b/mindspore-lite/tools/optimizer/graph/node_infershape.cc index 0ddc5d92..43222ffa 100644 --- a/mindspore-lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore-lite/tools/optimizer/graph/node_infershape.cc @@ -58,7 +58,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_z.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { static const std::unordered_set kNNACLToOpsInfer = { @@ -169,7 +169,7 @@ void RectifyFormat(const std::vector &inputs, FmkType fmk_type) tensor::TensorPtr NewTensorInfo(const lite::Tensor *tensor) { std::vector shape(tensor->shape()); std::vector shape_vector(shape.begin(), shape.end()); - auto tensor_info = tensor::empty(tensor->data_type(), shape_vector, device::DeviceType::kCPU); + auto tensor_info = tensor::from_spec(tensor->data_type(), shape_vector, device::DeviceType::kCPU); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed"; return nullptr; @@ -680,7 +680,7 @@ abstract::AbstractBasePtr NodeInferShape::ConvertTensorListToAbstract(lite::Tens } std::vector data_shape; data_shape.push_back(data_info.size()); - auto tensor_info = std::make_shared(kNumberTypeInt32, data_shape, data_info.data(), kNumberTypeInt32); + auto tensor_info = tensor::from_buffer(kNumberTypeInt32, data_shape, data_info.data(), kNumberTypeInt32); if (tensor_info == nullptr) { MS_LOG(ERROR) << "new tensor::Tensor failed"; return nullptr; diff --git a/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc b/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc index 5ce17203..6c1cd583 100644 --- a/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/output_variable_pass.cc @@ -29,7 +29,7 @@ #include "mindspore/ops/op_def/sequence_ops.h" #include "tools/common/func_graph_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::opt { namespace { @@ -79,7 +79,7 @@ bool OutputVariablePass::Run(const FuncGraphPtr &graph) { } abstract::ShapePtr shape = dyn_cast(make_tuple_input->Shape()); MS_CHECK_TRUE_MSG(shape != nullptr, false, "shape is nullptr!"); - tensor::TensorPtr tensor_data = tensor::empty(type_ptr->type_id(), shape->shape(), device::DeviceType::kCPU); + tensor::TensorPtr tensor_data = tensor::from_spec(type_ptr->type_id(), shape->shape(), device::DeviceType::kCPU); float *data_addr = static_cast(tensor_data->data_c()); for (size_t j = 0; i < tensor_data->DataSize(); ++j) { diff --git a/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc b/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc index 82aa86fc..95d27415 100644 --- a/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc @@ -35,7 +35,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" /* This pass changes the following pattern(s). 
@@ -120,7 +120,7 @@ ValueNodePtr ScalarOpPass::GenerateScalarValueTensor(const FuncGraphPtr &func_gr } int32_t scalar_value = *reinterpret_cast(data_info.data_.data()); ShapeVector const_data_shape = {1}; - tensor::TensorPtr const_data_tensor = tensor::empty(kNumberTypeInt32, const_data_shape, device::DeviceType::kCPU); + tensor::TensorPtr const_data_tensor = tensor::from_spec(kNumberTypeInt32, const_data_shape, device::DeviceType::kCPU); auto *val = static_cast(const_data_tensor->data_c()); *val = scalar_value; auto const_value_node = NewValueNode(const_data_tensor); @@ -206,7 +206,7 @@ CNodePtr ScalarOpPass::GenerateTensorShape(const FuncGraphPtr &func_graph, const } } else { auto shp_buf_size = sizeof(int64_t) * shape.size(); - auto tensor = std::make_shared(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size); + auto tensor = tensor::from_buffer(kNumberTypeInt64, tensor_shp, shape.data(), shp_buf_size); tmp_abstract = tensor->ToAbstract(); } diff --git a/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc b/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc index 720bed0b..c958e6a4 100644 --- a/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc +++ b/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc @@ -20,6 +20,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "include/common/utils/anfalgo.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { @@ -52,7 +53,7 @@ const AnfNodePtr SendOpAddControlDepend::Process(const FuncGraphPtr &func_graph, MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope() << ", input node: " << cnode->input(1)->fullname_with_scope(); - auto tensor = std::make_shared(0.0); + auto tensor = tensor::from_scalar(0.0); MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); auto value = MakeValue(tensor); diff --git a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc index 6f988070..dba686ed 100644 --- a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc @@ -38,7 +38,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" using mindspore::schema::PrimitiveType_Conv2DFusion; -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace opt { namespace { @@ -84,7 +84,7 @@ void CreateSplitConstantTensors(const tensor::TensorPtr &constant_tensor, const } auto cur_shape = UP_DIV(split_dim_size * visited_block, total_block_count); split_constant_shapes.at(i).at(split_dim) = cur_shape; - auto tensor = tensor::empty(weight_type_id, split_constant_shapes.at(i), device::DeviceType::kCPU); + auto tensor = tensor::from_spec(weight_type_id, split_constant_shapes.at(i), device::DeviceType::kCPU); if (tensor == nullptr) { MS_LOG(ERROR) << "make shared failed."; split_constant_tensors->clear(); -- Gitee From ed124944d143f15cca09066972e7997c33a43ef5 Mon Sep 17 00:00:00 2001 From: liuf9 Date: Mon, 30 Jun 2025 11:36:31 +0800 Subject: [PATCH 3/7] move nnacl to lite --- .jenkins/check/config/filter_cppcheck.txt | 4 + .jenkins/check/config/filter_cpplint.txt | 3 + .jenkins/check/config/whitelizard.txt | 205 + mindspore-lite/CMakeLists.txt | 6 +- mindspore-lite/java/native/CMakeLists.txt | 5 +- mindspore-lite/java/native/common/jni_utils.h | 2 +- mindspore-lite/minddata/CMakeLists.txt | 4 +- 
mindspore-lite/python/CMakeLists.txt | 2 +- mindspore-lite/src/CMakeLists.txt | 2 +- mindspore-lite/src/common/common.h | 2 +- mindspore-lite/src/common/graph_util.cc | 2 +- mindspore-lite/src/common/ops/CMakeLists.txt | 2 +- .../activation_grad_populate.cc | 2 +- .../operator_populate/activation_populate.cc | 2 +- .../ops/operator_populate/adder_populate.cc | 2 +- .../ops/operator_populate/affine_populate.cc | 2 +- .../operator_populate/all_gather_populate.cc | 2 +- .../operator_populate/arg_minmax_populate.cc | 2 +- .../arithmetic_operator_populate.h | 2 +- .../arithmetic_self_populate.cc | 2 +- .../operator_populate/attention_populate.cc | 2 +- .../audio_spectrogram_populate.cc | 2 +- .../base_operator_populate.cc | 20 +- .../operator_populate/batch_norm_populate.cc | 2 +- .../batch_to_space_populate.cc | 2 +- .../broadcast_to_populate.cc | 2 +- .../ops/operator_populate/call_populate.cc | 2 +- .../ops/operator_populate/clip_populate.cc | 2 +- .../ops/operator_populate/concat_populate.cc | 2 +- .../constant_of_shape_populate.cc | 2 +- .../ops/operator_populate/conv2d_populate.cc | 2 +- .../crop_and_resize_populate.cc | 2 +- .../ops/operator_populate/crop_populate.cc | 2 +- .../ops/operator_populate/cumsum_populate.cc | 2 +- .../ops/operator_populate/custom_populate.cc | 4 +- .../custom_predict_populate.cc | 2 +- .../operator_populate/deconv2d_populate.cc | 2 +- .../depth_to_space_populate.cc | 2 +- .../detection_post_process_populate.cc | 2 +- .../dynamic_quant_populate.cc | 2 +- .../embedding_lookup_populate.cc | 2 +- .../ops/operator_populate/exp_populate.cc | 2 +- .../ops/operator_populate/flatten_populate.cc | 2 +- .../full_connection_populate.cc | 2 +- .../fused_batchnorm_populate.cc | 2 +- .../ops/operator_populate/glu_populate.cc | 2 +- .../operator_populate/group_norm_populate.cc | 2 +- .../ops/operator_populate/gru_populate.cc | 2 +- .../instance_norm_populate.cc | 2 +- .../ops/operator_populate/l2_norm_populate.cc | 2 +- .../layer_norm_grad_populate.cc | 2 +- .../operator_populate/layer_norm_populate.cc | 2 +- .../local_response_normalization_populate.cc | 2 +- .../operator_populate/log_softmax_populate.cc | 2 +- .../lsh_projection_populate.cc | 2 +- .../ops/operator_populate/lstm_populate.cc | 2 +- .../ops/operator_populate/matmul_populate.cc | 2 +- .../ops/operator_populate/mfcc_populate.cc | 2 +- .../ops/operator_populate/nllloss_populate.cc | 2 +- .../non_max_suppression_populate.cc | 2 +- .../ops/operator_populate/one_hot_populate.cc | 2 +- .../operator_populate_register.h | 2 +- .../ops/operator_populate/p_relu_populate.cc | 2 +- .../ops/operator_populate/pad_populate.cc | 2 +- .../ops/operator_populate/partial_populate.cc | 2 +- .../ops/operator_populate/pooling_populate.cc | 2 +- .../ops/operator_populate/power_populate.cc | 2 +- .../operator_populate/prior_box_populate.cc | 2 +- .../quant_dtype_cast_populate.cc | 2 +- .../random_normal_populate.cc | 2 +- .../random_standard_normal_populate.cc | 2 +- .../ops/operator_populate/range_populate.cc | 2 +- .../ops/operator_populate/reduce_populate.cc | 2 +- .../ops/operator_populate/reduce_scatter.cc | 2 +- .../ops/operator_populate/resize_populate.cc | 2 +- .../ops/operator_populate/reverse_populate.cc | 2 +- .../reverse_sequence_populate.cc | 2 +- .../operator_populate/roi_pooling_populate.cc | 2 +- .../ops/operator_populate/scale_populate.cc | 2 +- .../scatter_element_populate.cc | 2 +- .../operator_populate/skip_gram_populate.cc | 2 +- .../ops/operator_populate/slice_populate.cc | 2 +- 
.../ops/operator_populate/softmax_populate.cc | 2 +- .../space_to_batch_nd_populate.cc | 2 +- .../space_to_batch_populate.cc | 2 +- .../space_to_depth_populate.cc | 2 +- ...tmax_cross_entropy_with_logits_populate.cc | 2 +- .../ops/operator_populate/splice_populate.cc | 4 +- .../ops/operator_populate/split_populate.cc | 4 +- .../split_with_overlap_populate.cc | 2 +- .../ops/operator_populate/squeeze_populate.cc | 2 +- .../stack_operator_populate.cc | 2 +- .../strided_slice_grad_populate.cc | 2 +- .../strided_slice_operator_populate.cc | 2 +- .../tensor_array_populate.cc | 2 +- .../tensor_list_from_tensor_populate.cc | 2 +- .../tensor_list_get_item_populate.cc | 2 +- .../tensor_list_reserve_populate.cc | 2 +- .../tensor_list_set_item_populate.cc | 2 +- .../tensor_list_stack_populate.cc | 2 +- .../tile_operator_populate.cc | 2 +- .../ops/operator_populate/topk_populate.cc | 2 +- .../uniform_real_populate.cc | 2 +- .../operator_populate/unsqueeze_populate.cc | 2 +- .../ops/operator_populate/unstack_populate.cc | 2 +- .../ops/populate/activation_grad_populate.cc | 2 +- .../ops/populate/activation_populate.cc | 2 +- .../src/common/ops/populate/adam_populate.cc | 2 +- .../src/common/ops/populate/add_populate.cc | 2 +- .../src/common/ops/populate/adder_populate.cc | 2 +- .../common/ops/populate/affine_populate.cc | 4 +- .../src/common/ops/populate/all_gather.cc | 2 +- .../common/ops/populate/argmax_populate.cc | 2 +- .../common/ops/populate/argmin_populate.cc | 2 +- .../common/ops/populate/arithmetic_populate.h | 2 +- .../ops/populate/arithmetic_self_populate.cc | 2 +- .../common/ops/populate/attention_populate.cc | 2 +- .../populate/audio_spectrogram_populate.cc | 2 +- .../ops/populate/batch_norm_populate.cc | 2 +- .../ops/populate/batch_to_space_populate.cc | 2 +- .../common/ops/populate/bias_add_populate.cc | 2 +- .../ops/populate/broadcast_to_populate.cc | 2 +- .../src/common/ops/populate/call_populate.cc | 2 +- .../src/common/ops/populate/clip_populate.cc | 2 +- .../common/ops/populate/concat_populate.cc | 2 +- .../populate/constant_of_shape_populate.cc | 2 +- .../populate/control/tensor_array_populate.cc | 4 +- .../control/tensorlistfromtensor_populate.cc | 2 +- .../control/tensorlistgetitem_populate.cc | 2 +- .../control/tensorlistreserve_populate.cc | 2 +- .../control/tensorlistsetlitem_populate.cc | 2 +- .../control/tensorliststack_populate.cc | 2 +- .../common/ops/populate/conv2d_populate.cc | 2 +- .../ops/populate/crop_and_resize_populate.cc | 2 +- .../src/common/ops/populate/crop_populate.cc | 2 +- .../common/ops/populate/cumsum_populate.cc | 2 +- .../common/ops/populate/custom_populate.cc | 16 +- .../common/ops/populate/deconv2d_populate.cc | 2 +- .../common/ops/populate/default_populate.h | 2 +- .../ops/populate/depth_to_space_populate.cc | 2 +- .../detection_post_process_populate.cc | 2 +- .../ops/populate/dynamic_quant_populate.cc | 2 +- .../ops/populate/embedding_lookup_populate.cc | 2 +- .../src/common/ops/populate/exp_populate.cc | 2 +- .../common/ops/populate/flatten_populate.cc | 2 +- .../ops/populate/full_connection_populate.cc | 2 +- .../ops/populate/fused_batchnorm_populate.cc | 2 +- .../common/ops/populate/gather_d_populate.cc | 2 +- .../common/ops/populate/gather_nd_populate.cc | 2 +- .../common/ops/populate/gather_populate.cc | 2 +- .../src/common/ops/populate/glu_populate.cc | 2 +- .../ops/populate/group_norm_populate.cc | 2 +- .../src/common/ops/populate/gru_populate.cc | 2 +- .../ops/populate/instance_norm_populate.cc | 2 +- .../common/ops/populate/l2_norm_populate.cc 
| 2 +- .../ops/populate/layer_norm_grad_populate.cc | 2 +- .../ops/populate/layer_norm_populate.cc | 2 +- .../local_response_normalization_populate.cc | 2 +- .../ops/populate/log_softmax_populate.cc | 2 +- .../src/common/ops/populate/lstm_populate.cc | 2 +- .../common/ops/populate/matmul_populate.cc | 2 +- .../src/common/ops/populate/mfcc_populate.cc | 2 +- .../src/common/ops/populate/mul_populate.cc | 2 +- .../common/ops/populate/nllloss_populate.cc | 2 +- .../populate/non_max_suppression_populate.cc | 2 +- .../common/ops/populate/one_hot_populate.cc | 2 +- .../common/ops/populate/p_relu_populate.cc | 2 +- .../src/common/ops/populate/pad_populate.cc | 2 +- .../common/ops/populate/partial_populate.cc | 2 +- .../common/ops/populate/pooling_populate.cc | 2 +- .../common/ops/populate/populate_register.h | 2 +- .../src/common/ops/populate/power_populate.cc | 2 +- .../common/ops/populate/prior_box_populate.cc | 2 +- .../ops/populate/quant_dtype_cast_populate.cc | 2 +- .../ops/populate/random_normal_populate.cc | 2 +- .../random_standard_normal_populate.cc | 2 +- .../src/common/ops/populate/range_populate.cc | 2 +- .../common/ops/populate/reduce_populate.cc | 2 +- .../src/common/ops/populate/reduce_scatter.cc | 2 +- .../common/ops/populate/reshape_populate.cc | 2 +- .../common/ops/populate/resize_populate.cc | 2 +- .../common/ops/populate/reverse_populate.cc | 2 +- .../ops/populate/reverse_sequence_populate.cc | 2 +- .../ops/populate/roi_pooling_populate.cc | 2 +- .../src/common/ops/populate/scale_populate.cc | 2 +- .../ops/populate/scatter_element_populate.cc | 2 +- .../ops/populate/scatter_nd_populate.cc | 2 +- .../populate/scatter_nd_update_populate.cc | 2 +- .../src/common/ops/populate/slice_populate.cc | 2 +- .../common/ops/populate/softmax_populate.cc | 2 +- .../populate/space_to_batch_nd_populate.cc | 2 +- .../ops/populate/space_to_batch_populate.cc | 2 +- .../ops/populate/space_to_depth_populate.cc | 2 +- ...tmax_cross_entropy_with_logits_populate.cc | 2 +- .../ops/populate/sparse_to_dense_populate.cc | 2 +- .../common/ops/populate/splice_populate.cc | 4 +- .../src/common/ops/populate/split_populate.cc | 4 +- .../populate/split_with_overlap_populate.cc | 2 +- .../common/ops/populate/squeeze_populate.cc | 2 +- .../src/common/ops/populate/stack_populate.cc | 2 +- .../populate/strided_slice_grad_populate.cc | 2 +- .../ops/populate/strided_slice_populate.h | 2 +- .../string/custom_predict_populate.cc | 2 +- .../string/lsh_projection_populate.cc | 2 +- .../ops/populate/string/skip_gram_populate.cc | 2 +- .../src/common/ops/populate/sub_populate.cc | 2 +- .../src/common/ops/populate/tile_populate.cc | 2 +- .../src/common/ops/populate/topk_populate.cc | 2 +- .../common/ops/populate/transpose_populate.cc | 2 +- .../common/ops/populate/triu_tril_populate.cc | 2 +- .../ops/populate/uniform_real_populate.cc | 2 +- .../common/ops/populate/unique_populate.cc | 2 +- .../common/ops/populate/unsqueeze_populate.cc | 2 +- .../common/ops/populate/unstack_populate.cc | 2 +- .../src/common/ops/populate/where_populate.cc | 2 +- mindspore-lite/src/common/prim_util.cc | 2 +- mindspore-lite/src/common/tensor_util.cc | 2 +- mindspore-lite/src/common/tensor_util.h | 8 +- .../control_flow/control_flow_scheduler.cc | 2 +- .../src/control_flow/control_flow_scheduler.h | 2 +- mindspore-lite/src/executor/kernel_exec.h | 2 +- .../src/executor/sub_graph_kernel.h | 2 +- .../src/extendrt/cxx_api/model/model_impl.cc | 2 +- .../model_pool/model_parallel_runner_impl.cc | 2 +- .../extendrt/cxx_api/model_pool/model_pool.cc | 2 
+- .../cxx_api/model_pool/model_worker.cc | 2 +- .../cxx_api/model_pool/resource_manager.cc | 2 +- .../src/extendrt/delegate/delegate_utils.cc | 2 +- .../extendrt/delegate/tensorrt/CMakeLists.txt | 4 +- .../delegate/tensorrt/cuda_impl/fse_decode.cu | 2 +- .../op/conv2dbackpropinput_tensorrt.cc | 2 +- .../delegate/tensorrt/op/deconv3d_tensorrt.cc | 2 +- .../tensorrt/op/deconvolution_tensorrt.cc | 2 +- .../delegate/tensorrt/op/resize_tensorrt.cc | 2 +- .../tensorrt/optimizer/tensorrt_optimizer.cc | 2 +- .../delegate/tensorrt/tensorrt_utils.h | 2 +- .../graph_compiler/single_graph_scheduler.cc | 2 +- mindspore-lite/src/extendrt/infer_session.cc | 2 +- .../kernel/ascend/model/dyn_shape_process.cc | 2 +- .../kernel/cpu/transpose_kernel_mod.cc | 2 +- .../kernel/cpu/transpose_kernel_mod.h | 2 +- .../src/extendrt/kernel/cuda/batchtospace.cc | 2 +- .../mindir_model/mindir_model_util.cc | 2 +- .../populate/arithmetic_populate.h | 2 +- .../base_operator_populate_register.h | 2 +- mindspore-lite/src/infer/primitive_type.cc | 2 +- mindspore-lite/src/litert/cpu_info.cc | 2 +- mindspore-lite/src/litert/cpu_info.h | 2 +- .../litert/delegate/coreml/op/coreml_op.cc | 2 +- .../src/litert/delegate/coreml/op/coreml_op.h | 2 +- .../src/litert/delegate/delegate_utils.cc | 2 +- .../src/litert/delegate/delegate_utils.h | 2 +- .../src/litert/delegate/npu/CMakeLists.txt | 2 +- .../litert/delegate/npu/npu_converter_utils.h | 2 +- .../delegate/npu/op/convolution_base_npu.cc | 2 +- .../delegate/npu/op/deconvolution_npu.cc | 2 +- .../src/litert/delegate/npu/op/npu_op.h | 2 +- .../litert/delegate/npu/transpose_kernel.cc | 2 +- mindspore-lite/src/litert/infer_manager.cc | 2 +- mindspore-lite/src/litert/infer_manager.h | 4 +- mindspore-lite/src/litert/inner_context.h | 4 +- .../kernel/ascend/src/custom_interface.cc | 2 +- .../litert/kernel/cpu/base/arithmetic_base.cc | 2 +- .../litert/kernel/cpu/base/arithmetic_base.h | 2 +- .../kernel/cpu/base/constant_of_shape.h | 6 +- .../litert/kernel/cpu/base/custom_is_inf.cc | 2 +- .../kernel/cpu/base/custom_masked_fill.cc | 2 +- .../kernel/cpu/base/custom_tensor_scatter.cc | 2 +- .../cpu/base/detection_post_process_base.cc | 2 +- .../cpu/base/detection_post_process_base.h | 2 +- .../kernel/cpu/base/format_transpose.cc | 2 +- .../litert/kernel/cpu/base/format_transpose.h | 2 +- .../kernel/cpu/base/group_convolution_base.h | 4 +- .../cpu/base/group_convolution_creator.h | 2 +- .../litert/kernel/cpu/base/layout_transform.h | 2 +- .../kernel/cpu/base/quant_dtype_cast.cc | 2 +- .../litert/kernel/cpu/base/random_normal.h | 2 +- .../src/litert/kernel/cpu/base/reduce_base.h | 2 +- .../src/litert/kernel/cpu/base/resize_base.h | 2 +- .../litert/kernel/cpu/base/scatter_nd_base.h | 2 +- .../kernel/cpu/base/scatter_nd_binary.h | 2 +- .../src/litert/kernel/cpu/base/split_base.h | 4 +- .../cpu/base/split_with_over_lap_base.cc | 2 +- .../cpu/base/split_with_over_lap_base.h | 4 +- .../litert/kernel/cpu/base/transpose_base.h | 2 +- .../kernel/cpu/bolt/bolt_parameter_manager.cc | 2 +- .../kernel/cpu/bolt/bolt_parameter_manager.h | 2 +- .../src/litert/kernel/cpu/bolt/bolt_utils.h | 2 +- .../kernel/cpu/bolt/convolution_bolt.cc | 4 +- .../litert/kernel/cpu/control/tensor_array.h | 2 +- .../cpu/control/tensorlist_fromtensor.h | 2 +- .../kernel/cpu/control/tensorlist_getitem.h | 2 +- .../kernel/cpu/control/tensorlist_reserve.h | 2 +- .../kernel/cpu/control/tensorlist_setitem.h | 2 +- .../kernel/cpu/control/tensorlist_stack.h | 2 +- .../src/litert/kernel/cpu/fp16/biasadd_fp16.h | 2 +- 
.../src/litert/kernel/cpu/fp16/cast_fp16.h | 6 +- .../src/litert/kernel/cpu/fp16/common_fp16.cc | 2 +- .../kernel/cpu/fp16/convolution_1x1_fp16.cc | 8 +- .../kernel/cpu/fp16/convolution_1x1_fp16.h | 4 +- .../cpu/fp16/convolution_delegate_fp16.cc | 2 +- .../cpu/fp16/convolution_delegate_fp16.h | 4 +- .../fp16/convolution_depthwise_3x3_fp16.cc | 4 +- .../cpu/fp16/convolution_depthwise_3x3_fp16.h | 2 +- .../cpu/fp16/convolution_depthwise_fp16.cc | 4 +- .../cpu/fp16/convolution_depthwise_fp16.h | 2 +- .../convolution_depthwise_slidewindow_fp16.cc | 4 +- .../convolution_depthwise_slidewindow_fp16.h | 2 +- .../kernel/cpu/fp16/convolution_fp16.cc | 10 +- .../cpu/fp16/convolution_winograd_fp16.h | 6 +- .../litert/kernel/cpu/fp16/custom_gru_fp16.cc | 8 +- .../cpu/fp16/deconvolution_depthwise_fp16.cc | 2 +- .../cpu/fp16/deconvolution_depthwise_fp16.h | 2 +- .../kernel/cpu/fp16/deconvolution_fp16.h | 4 +- .../cpu/fp16/deconvolution_winograd_fp16.h | 6 +- .../kernel/cpu/fp16/dynamic_quant_fp16.cc | 6 +- .../litert/kernel/cpu/fp16/fp16_op_handler.h | 4 +- .../kernel/cpu/fp16/group_convolution_fp16.h | 4 +- .../src/litert/kernel/cpu/fp16/gru_fp16.cc | 8 +- .../src/litert/kernel/cpu/fp16/gru_fp16.h | 2 +- .../kernel/cpu/fp16/instance_norm_fp16.cc | 6 +- .../kernel/cpu/fp16/instance_norm_fp16.h | 2 +- .../kernel/cpu/fp16/layout_transform_fp16.cc | 2 +- .../litert/kernel/cpu/fp16/lstm_fp16_base.cc | 4 +- .../litert/kernel/cpu/fp16/lstm_fp16_base.h | 2 +- .../kernel/cpu/fp16/lstm_mindir_fp16.cc | 2 +- .../kernel/cpu/fp16/lstm_non_mindir_fp16.cc | 2 +- .../kernel/cpu/fp16/matmul_base_fp16.cc | 4 +- .../litert/kernel/cpu/fp16/matmul_base_fp16.h | 2 +- .../kernel/cpu/fp16/quant_dtype_cast_fp16.cc | 4 +- .../src/litert/kernel/cpu/fp16/resize_fp16.h | 2 +- .../cpu/fp16_grad/activation_fp16_grad.cc | 2 +- .../cpu/fp16_grad/activation_fp16_grad.h | 4 +- .../cpu/fp16_grad/arithmetic_fp16_grad.cc | 2 +- .../cpu/fp16_grad/arithmetic_fp16_grad.h | 2 +- .../cpu/fp16_grad/arithmetic_fp16_self_grad.h | 2 +- .../kernel/cpu/fp16_grad/bias_fp16_grad.h | 2 +- .../kernel/cpu/fp16_grad/bn_fp16_grad.cc | 2 +- .../kernel/cpu/fp16_grad/bn_fp16_grad.h | 2 +- .../fp16_grad/convolution_fp16_grad_filter.cc | 10 +- .../fp16_grad/convolution_fp16_grad_input.cc | 8 +- .../kernel/cpu/fp16_grad/dropout_fp16_grad.cc | 4 +- .../cpu/fp16_grad/layernorm_fp16_grad.cc | 4 +- .../kernel/cpu/fp16_grad/neg_fp16_grad.cc | 2 +- .../kernel/cpu/fp16_grad/pooling_fp16_grad.cc | 4 +- .../kernel/cpu/fp16_grad/pooling_fp16_grad.h | 2 +- .../kernel/cpu/fp16_grad/resize_fp16_grad.cc | 4 +- .../cpu/fp16_grad/strided_slice_fp16_grad.cc | 2 +- .../cpu/fp16_grad/strided_slice_fp16_grad.h | 2 +- .../fp16_grad/unsorted_segment_sum_fp16.cc | 2 +- .../src/litert/kernel/cpu/fp32/adder_fp32.cc | 4 +- .../src/litert/kernel/cpu/fp32/adder_fp32.h | 2 +- .../src/litert/kernel/cpu/fp32/affine_fp32.cc | 4 +- .../src/litert/kernel/cpu/fp32/affine_fp32.h | 4 +- .../litert/kernel/cpu/fp32/all_gather_fp32.h | 2 +- .../litert/kernel/cpu/fp32/arithmetic_fp32.cc | 2 +- .../kernel/cpu/fp32/broadcast_to_fp32.h | 2 +- .../src/litert/kernel/cpu/fp32/cast_fp32.h | 4 +- .../kernel/cpu/fp32/convolution_1x1_fp32.h | 10 +- .../cpu/fp32/convolution_delegate_fp32.cc | 4 +- .../cpu/fp32/convolution_delegate_fp32.h | 6 +- .../fp32/convolution_depthwise_3x3_fp32.cc | 2 +- .../cpu/fp32/convolution_depthwise_3x3_fp32.h | 2 +- .../cpu/fp32/convolution_depthwise_fp32.cc | 4 +- .../cpu/fp32/convolution_depthwise_fp32.h | 2 +- .../convolution_depthwise_indirect_fp32.h | 2 +- 
.../convolution_depthwise_slidewindow_fp32.h | 2 +- ...nvolution_depthwise_slidewindow_x86_fp32.h | 2 +- .../kernel/cpu/fp32/convolution_fp32.cc | 6 +- .../litert/kernel/cpu/fp32/convolution_fp32.h | 2 +- .../cpu/fp32/convolution_im2col_arm64_fp32.cc | 2 +- .../fp32/convolution_im2col_avx512_fp32.cc | 2 +- .../cpu/fp32/convolution_im2col_avx_fp32.cc | 2 +- .../cpu/fp32/convolution_im2col_base_fp32.cc | 6 +- .../cpu/fp32/convolution_im2col_base_fp32.h | 2 +- .../cpu/fp32/convolution_im2col_fp32.cc | 2 +- .../kernel/cpu/fp32/convolution_im2col_fp32.h | 2 +- .../convolution_slidewindow_arm64_fp32.cc | 2 +- .../fp32/convolution_slidewindow_avx_fp32.cc | 4 +- .../cpu/fp32/convolution_slidewindow_fp32.cc | 4 +- .../cpu/fp32/convolution_slidewindow_fp32.h | 2 +- .../kernel/cpu/fp32/convolution_sw_1x1_fp32.h | 6 +- .../fp32/convolution_winograd_arm64_fp32.cc | 2 +- .../cpu/fp32/convolution_winograd_avx_fp32.cc | 4 +- .../fp32/convolution_winograd_base_fp32.cc | 4 +- .../cpu/fp32/convolution_winograd_base_fp32.h | 6 +- .../cpu/fp32/convolution_winograd_fp32.cc | 2 +- .../cpu/fp32/convolution_winograd_fp32.h | 2 +- .../src/litert/kernel/cpu/fp32/cumsum_fp32.cc | 2 +- .../src/litert/kernel/cpu/fp32/cumsum_fp32.h | 2 +- .../litert/kernel/cpu/fp32/custom_gru_fp32.cc | 6 +- .../cpu/fp32/deconvolution_depthwise_fp32.h | 2 +- .../kernel/cpu/fp32/deconvolution_fp32.h | 4 +- .../cpu/fp32/deconvolution_winograd_fp32.h | 4 +- .../cpu/fp32/detection_post_process_fp32.cc | 2 +- .../cpu/fp32/detection_post_process_fp32.h | 2 +- .../kernel/cpu/fp32/embedding_lookup_fp32.h | 2 +- .../src/litert/kernel/cpu/fp32/glu_fp32.cc | 4 +- .../src/litert/kernel/cpu/fp32/glu_fp32.h | 6 +- .../kernel/cpu/fp32/group_convolution_fp32.h | 2 +- .../src/litert/kernel/cpu/fp32/gru_fp32.cc | 4 +- .../src/litert/kernel/cpu/fp32/gru_fp32.h | 2 +- .../kernel/cpu/fp32/instance_norm_fp32.cc | 4 +- .../kernel/cpu/fp32/instance_norm_fp32.h | 2 +- .../cpu/fp32/invert_permutation_fp32.cc | 4 +- .../litert/kernel/cpu/fp32/l2_norm_fp32.cc | 2 +- .../src/litert/kernel/cpu/fp32/l2_norm_fp32.h | 2 +- .../litert/kernel/cpu/fp32/lstm_fp32_base.cc | 4 +- .../litert/kernel/cpu/fp32/lstm_fp32_base.h | 2 +- .../kernel/cpu/fp32/lstm_mindir_fp32.cc | 2 +- .../kernel/cpu/fp32/lstm_non_mindir_fp32.cc | 2 +- .../src/litert/kernel/cpu/fp32/matmul_fp32.cc | 4 +- .../src/litert/kernel/cpu/fp32/matmul_fp32.h | 2 +- .../kernel/cpu/fp32/matmul_fp32_arm32.cc | 4 +- .../kernel/cpu/fp32/matmul_fp32_arm64.cc | 6 +- .../litert/kernel/cpu/fp32/matmul_fp32_avx.cc | 4 +- .../kernel/cpu/fp32/matmul_fp32_avx512.cc | 8 +- .../kernel/cpu/fp32/matmul_fp32_base.cc | 6 +- .../litert/kernel/cpu/fp32/matmul_fp32_base.h | 2 +- .../litert/kernel/cpu/fp32/matmul_fp32_sse.cc | 4 +- .../cpu/fp32/non_max_suppression_fp32.cc | 2 +- .../cpu/fp32/non_max_suppression_fp32.h | 2 +- .../online_fusion/cast_gather_reduce_fp32.cc | 2 +- .../fp32/online_fusion/reduce_concat_fp32.cc | 2 +- .../online_fusion/split_reduce_concat_fp32.cc | 2 +- .../online_fusion/split_reduce_concat_fp32.h | 2 +- .../kernel/cpu/fp32/reduce_scatter_fp32.h | 2 +- .../fp32/relative_position_attention_fp32.cc | 2 +- .../fp32/relative_position_attention_fp32.h | 4 +- .../src/litert/kernel/cpu/fp32/resize_fp32.h | 2 +- .../kernel/cpu/fp32/reverse_sequence_fp32.h | 2 +- .../kernel/cpu/fp32/roi_pooling_fp32.cc | 2 +- .../litert/kernel/cpu/fp32/roi_pooling_fp32.h | 2 +- .../kernel/cpu/fp32/space_to_batch_fp32.h | 4 +- .../kernel/cpu/fp32/space_to_depth_fp32.cc | 4 +- .../kernel/cpu/fp32/space_to_depth_fp32.h | 2 +- 
.../cpu/fp32/sparse_fill_empty_rows_fp32.cc | 2 +- .../kernel/cpu/fp32/sparse_reshape_fp32.cc | 2 +- .../cpu/fp32/sparse_segment_sum_fp32.cc | 2 +- .../kernel/cpu/fp32/sparse_to_dense_fp32.cc | 4 +- .../kernel/cpu/fp32/sparse_to_dense_fp32.h | 2 +- .../src/litert/kernel/cpu/fp32/topk_fp32.h | 4 +- .../kernel/cpu/fp32/transpose_server_fp32.cc | 2 +- .../kernel/cpu/fp32/transpose_server_fp32.h | 2 +- .../kernel/cpu/fp32/uniform_real_fp32.h | 2 +- .../src/litert/kernel/cpu/fp32/unstack_fp32.h | 2 +- .../kernel/cpu/fp32_grad/activation_grad.cc | 2 +- .../kernel/cpu/fp32_grad/activation_grad.h | 2 +- .../src/litert/kernel/cpu/fp32_grad/adam.cc | 4 +- .../src/litert/kernel/cpu/fp32_grad/adam.h | 2 +- .../kernel/cpu/fp32_grad/adam_weight_decay.cc | 2 +- .../kernel/cpu/fp32_grad/apply_momentum.h | 2 +- .../kernel/cpu/fp32_grad/arithmetic_grad.cc | 6 +- .../kernel/cpu/fp32_grad/arithmetic_grad.h | 2 +- .../cpu/fp32_grad/arithmetic_self_grad.cc | 6 +- .../src/litert/kernel/cpu/fp32_grad/assign.h | 2 +- .../litert/kernel/cpu/fp32_grad/bias_grad.h | 2 +- .../cpu/fp32_grad/binary_cross_entropy.cc | 2 +- .../fp32_grad/binary_cross_entropy_grad.cc | 2 +- .../litert/kernel/cpu/fp32_grad/bn_grad.cc | 2 +- .../kernel/cpu/fp32_grad/convolution.cc | 6 +- .../cpu/fp32_grad/convolution_grad_filter.cc | 10 +- .../cpu/fp32_grad/convolution_grad_input.cc | 8 +- .../fp32_grad/deconvolution_grad_filter.cc | 6 +- .../litert/kernel/cpu/fp32_grad/dropout.cc | 2 +- .../kernel/cpu/fp32_grad/dropout_grad.cc | 4 +- .../kernel/cpu/fp32_grad/layernorm_grad.cc | 6 +- .../cpu/fp32_grad/lstm_grad_data_fp32.cc | 2 +- .../cpu/fp32_grad/lstm_grad_data_fp32.h | 2 +- .../kernel/cpu/fp32_grad/lstm_grad_fp32.cc | 2 +- .../kernel/cpu/fp32_grad/lstm_grad_fp32.h | 2 +- .../cpu/fp32_grad/lstm_grad_weight_fp32.cc | 2 +- .../cpu/fp32_grad/lstm_grad_weight_fp32.h | 2 +- .../litert/kernel/cpu/fp32_grad/neg_grad.cc | 2 +- .../kernel/cpu/fp32_grad/nllloss_grad.cc | 2 +- .../kernel/cpu/fp32_grad/nllloss_grad.h | 2 +- .../kernel/cpu/fp32_grad/pooling_grad.cc | 4 +- .../kernel/cpu/fp32_grad/pooling_grad.h | 2 +- .../litert/kernel/cpu/fp32_grad/power_grad.cc | 4 +- .../litert/kernel/cpu/fp32_grad/power_grad.h | 4 +- .../kernel/cpu/fp32_grad/resize_grad.cc | 6 +- .../src/litert/kernel/cpu/fp32_grad/sgd.h | 2 +- .../kernel/cpu/fp32_grad/smooth_l1_loss.h | 2 +- .../cpu/fp32_grad/smooth_l1_loss_grad.h | 2 +- .../softmax_cross_entropy_with_logits.cc | 6 +- .../softmax_cross_entropy_with_logits.h | 6 +- .../kernel/cpu/fp32_grad/softmax_grad.cc | 2 +- .../kernel/cpu/fp32_grad/softmax_grad.h | 2 +- ...parse_softmax_cross_entropy_with_logits.cc | 6 +- ...sparse_softmax_cross_entropy_with_logits.h | 6 +- .../cpu/fp32_grad/strided_slice_grad.cc | 2 +- .../kernel/cpu/fp32_grad/strided_slice_grad.h | 2 +- .../cpu/fp32_grad/unsorted_segment_sum.cc | 2 +- .../cpu/fp32_sparse/matmul_sparse_fp32.cc | 6 +- .../cpu/fp32_sparse/matmul_sparse_fp32.h | 4 +- .../src/litert/kernel/cpu/int8/add_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/add_int8.h | 4 +- .../litert/kernel/cpu/int8/argminmax_int8.h | 8 +- .../litert/kernel/cpu/int8/arithmetic_int8.cc | 2 +- .../litert/kernel/cpu/int8/arithmetic_int8.h | 2 +- .../kernel/cpu/int8/arithmetic_self_int8.h | 4 +- .../kernel/cpu/int8/batch_to_space_int8.h | 6 +- .../litert/kernel/cpu/int8/batchnorm_int8.cc | 2 +- .../litert/kernel/cpu/int8/batchnorm_int8.h | 4 +- .../src/litert/kernel/cpu/int8/concat_int8.h | 4 +- .../kernel/cpu/int8/convolution_1x1_int8.h | 8 +- .../kernel/cpu/int8/convolution_3x3_int8.cc | 2 +- 
.../kernel/cpu/int8/convolution_3x3_int8.h | 2 +- .../int8/convolution_depthwise_3x3_int8.cc | 2 +- .../cpu/int8/convolution_depthwise_3x3_int8.h | 2 +- .../cpu/int8/convolution_depthwise_int8.cc | 2 +- .../cpu/int8/convolution_depthwise_int8.h | 2 +- .../convolution_depthwise_slidewindow_int8.cc | 2 +- .../convolution_depthwise_slidewindow_int8.h | 2 +- .../kernel/cpu/int8/convolution_int8.cc | 2 +- .../litert/kernel/cpu/int8/convolution_int8.h | 2 +- .../cpu/int8/convolution_int8_creator.h | 2 +- .../src/litert/kernel/cpu/int8/crop_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/crop_int8.h | 2 +- .../cpu/int8/deconvolution_depthwise_int8.cc | 2 +- .../cpu/int8/deconvolution_depthwise_int8.h | 2 +- .../kernel/cpu/int8/deconvolution_int8.h | 8 +- .../kernel/cpu/int8/depth_to_space_int8.cc | 2 +- .../kernel/cpu/int8/depth_to_space_int8.h | 8 +- .../cpu/int8/detection_post_process_int8.cc | 2 +- .../cpu/int8/detection_post_process_int8.h | 2 +- .../src/litert/kernel/cpu/int8/div_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/div_int8.h | 2 +- .../kernel/cpu/int8/dynamic_gather_int8.cc | 6 +- .../kernel/cpu/int8/dynamic_gather_int8.h | 4 +- .../litert/kernel/cpu/int8/dynamic_quant.cc | 8 +- .../litert/kernel/cpu/int8/dynamic_quant.h | 2 +- .../litert/kernel/cpu/int8/gatherNd_int8.cc | 2 +- .../litert/kernel/cpu/int8/gatherNd_int8.h | 2 +- .../src/litert/kernel/cpu/int8/gather_int8.cc | 6 +- .../src/litert/kernel/cpu/int8/gather_int8.h | 4 +- .../kernel/cpu/int8/group_convolution_int8.h | 2 +- .../src/litert/kernel/cpu/int8/hswish_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/hswish_int8.h | 4 +- .../src/litert/kernel/cpu/int8/l2_norm_int8.h | 2 +- .../litert/kernel/cpu/int8/layer_norm_int8.h | 4 +- .../litert/kernel/cpu/int8/leaky_relu_int8.h | 4 +- .../litert/kernel/cpu/int8/matmul_base_int8.h | 10 +- .../cpu/int8/matmul_dynamic_base_int8.cc | 2 +- .../cpu/int8/matmul_dynamic_base_int8.h | 8 +- .../kernel/cpu/int8/matmul_dynamic_int8.cc | 4 +- .../cpu/int8/matmul_dynamic_sdot_int8.cc | 4 +- .../src/litert/kernel/cpu/int8/matmul_int8.cc | 4 +- .../src/litert/kernel/cpu/int8/matmul_int8.h | 4 +- .../src/litert/kernel/cpu/int8/mul_int8.h | 6 +- .../litert/kernel/cpu/int8/opt_op_handler.cc | 2 +- .../litert/kernel/cpu/int8/opt_op_handler.h | 2 +- .../src/litert/kernel/cpu/int8/pad_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/pad_int8.h | 8 +- .../litert/kernel/cpu/int8/pooling_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/pooling_int8.h | 2 +- .../src/litert/kernel/cpu/int8/power_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/power_int8.h | 4 +- .../src/litert/kernel/cpu/int8/reduce_int8.cc | 4 +- .../src/litert/kernel/cpu/int8/reduce_int8.h | 6 +- .../src/litert/kernel/cpu/int8/relux_int8.h | 4 +- .../litert/kernel/cpu/int8/reshape_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/reshape_int8.h | 2 +- .../src/litert/kernel/cpu/int8/resize_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/resize_int8.h | 2 +- .../src/litert/kernel/cpu/int8/scale_int8.h | 8 +- .../litert/kernel/cpu/int8/sigmoid_int8.cc | 4 +- .../src/litert/kernel/cpu/int8/sigmoid_int8.h | 2 +- .../src/litert/kernel/cpu/int8/slice_int8.cc | 4 +- .../src/litert/kernel/cpu/int8/slice_int8.h | 6 +- .../litert/kernel/cpu/int8/softmax_int8.cc | 2 +- .../src/litert/kernel/cpu/int8/softmax_int8.h | 4 +- .../kernel/cpu/int8/space_to_batch_int8.cc | 4 +- .../src/litert/kernel/cpu/int8/split_int8.cc | 4 +- .../src/litert/kernel/cpu/int8/squeeze_int8.h | 4 +- .../src/litert/kernel/cpu/int8/sub_int8.h | 6 +- .../src/litert/kernel/cpu/int8/tanh_int8.h 
| 4 +- .../src/litert/kernel/cpu/int8/topk_int8.h | 2 +- .../litert/kernel/cpu/int8/transpose_int8.cc | 4 +- .../litert/kernel/cpu/int8/unsqueeze_int8.cc | 2 +- .../litert/kernel/cpu/int8/unsqueeze_int8.h | 2 +- .../kernel/cpu/nnacl/nnacl_batchnorm.cc | 2 +- .../kernel/cpu/nnacl/nnacl_convolution.cc | 4 +- .../cpu/nnacl/nnacl_fused_batch_norm.cc | 4 +- .../litert/kernel/cpu/nnacl/nnacl_kernel.cc | 2 +- .../litert/kernel/cpu/nnacl/nnacl_kernel.h | 2 +- .../litert/kernel/cpu/nnacl/nnacl_matmul.cc | 2 +- .../litert/kernel/cpu/nnacl/nnacl_matmul.h | 2 +- .../litert/kernel/cpu/nnacl/nnacl_reduce.cc | 2 +- .../kernel/cpu/nnacl/nnacl_strided_slice.cc | 2 +- .../litert/kernel/cpu/nnacl_c/CMakeLists.txt | 293 ++ .../src/litert/kernel/cpu/nnacl_c/OWNERS | 11 + .../src/litert/kernel/cpu/nnacl_c/README.md | 1 + .../kernel/cpu/nnacl_c/activation_parameter.h | 29 + .../kernel/cpu/nnacl_c/affine_parameter.h | 32 + .../kernel/cpu/nnacl_c/all_gather_parameter.h | 30 + .../cpu/nnacl_c/arg_min_max_parameter.h | 30 + .../kernel/cpu/nnacl_c/arithmetic_parameter.h | 48 + .../cpu/nnacl_c/arithmetic_self_parameter.h | 30 + .../assembly/arm32/ConvDw3x3Int8BorderPixel.S | 128 + .../nnacl_c/assembly/arm32/ConvDwFp32Border.S | 75 + .../nnacl_c/assembly/arm32/ConvDwFp32Center.S | 176 + .../nnacl_c/assembly/arm32/ConvDwFp32Row.S | 125 + .../nnacl_c/assembly/arm32/ConvDwInt8Center.S | 290 ++ .../assembly/arm32/ConvDwInt8PostAlign4.S | 120 + .../arm32/ConvDwInt8PostAlign4PerChannel.S | 123 + .../nnacl_c/assembly/arm32/ConvDwInt8Row.S | 144 + .../assembly/arm32/DeconvDwFp32Center.S | 79 + .../assembly/arm32/DeconvDwInt8Center.S | 79 + .../nnacl_c/assembly/arm32/DeconvDwInt8Post.S | 84 + .../arm32/IndirectGemmInt16to32_8x4.S | 249 + .../assembly/arm32/IndirectGemmInt8_2x4.S | 306 ++ .../nnacl_c/assembly/arm32/MatVecMulFp32.S | 195 + .../cpu/nnacl_c/assembly/arm32/MatmulFp32.S | 381 ++ .../nnacl_c/assembly/arm32/MatmulFp32Opt.S | 422 ++ .../assembly/arm32/MatmulFp32Opt12x4.S | 578 +++ .../cpu/nnacl_c/assembly/arm32/MatmulInt8.S | 298 ++ .../nnacl_c/assembly/arm32/MatmulInt8Opt.S | 300 ++ .../assembly/arm32/MatmulWinogradFp32.S | 186 + .../assembly/arm32/PostFuncBiasReluC4.S | 248 + .../assembly/arm32/PostFuncBiasReluC8.S | 450 ++ .../assembly/arm32/PreSum4x16Int8Peroc.S | 143 + .../assembly/arm32/PreSum4x16Int8Pert.S | 94 + .../assembly/arm32/TiledC4MatmulFp32.S | 211 + .../assembly/arm32/WinogradTransLeft.S | 230 + .../assembly/arm32/WinogradTransRight.S | 220 + .../cpu/nnacl_c/assembly/arm64/AdderFp32.S | 622 +++ .../nnacl_c/assembly/arm64/BigMatmulFp32Opt.S | 2528 ++++++++++ .../assembly/arm64/ConvDw3x3Fp32Corner.S | 114 + .../assembly/arm64/ConvDw3x3Fp32Horizontal.S | 130 + .../assembly/arm64/ConvDw3x3Fp32Stride1.S | 210 + .../assembly/arm64/ConvDw3x3Fp32Stride2.S | 212 + .../assembly/arm64/ConvDw3x3Fp32Vertical.S | 126 + .../nnacl_c/assembly/arm64/ConvDw3x3Int8.S | 500 ++ .../assembly/arm64/ConvDw3x3Int8Corner.S | 222 + .../assembly/arm64/ConvDw3x3Int8Horizontal.S | 255 + .../assembly/arm64/ConvDw3x3Int8Stride2.S | 474 ++ .../assembly/arm64/ConvDw3x3Int8Vertical.S | 245 + .../nnacl_c/assembly/arm64/ConvDw3x3Line.S | 203 + .../nnacl_c/assembly/arm64/ConvDwFp32Border.S | 68 + .../nnacl_c/assembly/arm64/ConvDwFp32Center.S | 313 ++ .../assembly/arm64/ConvDwFp32Indirect3x3.S | 159 + .../assembly/arm64/ConvDwFp32Indirect5x5.S | 304 ++ .../nnacl_c/assembly/arm64/ConvDwFp32Row.S | 129 + .../nnacl_c/assembly/arm64/ConvDwInt8Center.S | 294 ++ .../assembly/arm64/ConvDwInt8PostAlign4.S | 191 + 
.../arm64/ConvDwInt8PostAlign4PerChannel.S | 119 + .../nnacl_c/assembly/arm64/ConvDwInt8Row.S | 134 + .../nnacl_c/assembly/arm64/ConvFp32Center.S | 458 ++ .../nnacl_c/assembly/arm64/ConvSW1x16Kernel.S | 421 ++ .../nnacl_c/assembly/arm64/ConvSW1x8Kernel.S | 278 ++ .../nnacl_c/assembly/arm64/ConvSW2x16Kernel.S | 407 ++ .../nnacl_c/assembly/arm64/ConvSW2x8Kernel.S | 265 + .../nnacl_c/assembly/arm64/ConvSW3x16Kernel.S | 533 ++ .../nnacl_c/assembly/arm64/ConvSW3x8Kernel.S | 332 ++ .../nnacl_c/assembly/arm64/ConvSW4x16Kernel.S | 662 +++ .../nnacl_c/assembly/arm64/ConvSW4x8Kernel.S | 406 ++ .../nnacl_c/assembly/arm64/ConvSW5x16Kernel.S | 457 ++ .../nnacl_c/assembly/arm64/ConvSW5x8Kernel.S | 308 ++ .../assembly/arm64/DeconvDwFp32Border.S | 56 + .../assembly/arm64/DeconvDwFp32Center.S | 75 + .../assembly/arm64/DeconvDwInt8Center.S | 75 + .../nnacl_c/assembly/arm64/DeconvDwInt8Post.S | 66 + .../assembly/arm64/DynamicGatherArm64.S | 48 + .../arm64/IndirectGemmInt16to32_8x4.S | 233 + .../nnacl_c/assembly/arm64/MatVecMulFp32.S | 252 + .../assembly/arm64/MatVecMulPackFp32.S | 198 + .../cpu/nnacl_c/assembly/arm64/MatmulFp32.S | 787 +++ .../nnacl_c/assembly/arm64/MatmulFp32Opt.S | 1669 +++++++ .../assembly/arm64/MatmulFp32OptRow12.S | 1229 +++++ .../assembly/arm64/MatmulFp32OptRow4.S | 597 +++ .../assembly/arm64/MatmulFp32OptRow8.S | 911 ++++ .../cpu/nnacl_c/assembly/arm64/MatmulInt8.S | 420 ++ .../nnacl_c/assembly/arm64/MatmulInt8Opt.S | 356 ++ .../cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S | 193 + .../assembly/arm64/MatmulWinogradFp32.S | 183 + .../assembly/arm64/PostFuncBiasReluC4.S | 316 ++ .../assembly/arm64/PostFuncBiasReluC8.S | 553 +++ .../assembly/arm64/PostFuncInt8C4Neon64.S | 259 + .../assembly/arm64/PreSum4x16Int8Peroc.S | 140 + .../assembly/arm64/PreSum4x16Int8Pert.S | 81 + .../cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S | 294 ++ .../assembly/arm64/TiledC4MatmulFp32.S | 279 ++ .../assembly/arm64/WinogradTransLeft.S | 158 + .../assembly/arm64/WinogradTransRight.S | 160 + .../arm82_aarch32_fp16/Float16Tofloat32.S | 70 + .../arm82_aarch32_fp16/Float32ToFloat16.S | 70 + .../arm82_aarch32_fp16/MatVecMulFp16.S | 237 + .../arm82_aarch32_fp16/Matmul12x8Fp16.S | 617 +++ .../arm82_aarch32_fp16/TiledC4MatmulFp16.S | 108 + .../arm82_aarch32_fp16/WinogradTransLeft.S | 165 + .../arm82_aarch32_fp16/WinogradTransRight.S | 163 + .../nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S | 313 ++ .../assembly/avx/ConvDwFp32BorderAvx.S | 188 + .../nnacl_c/assembly/avx/ConvDwFp32RowAvx.S | 189 + .../assembly/avx/ConvDwFp32RowOptAVX.S | 382 ++ .../cpu/nnacl_c/assembly/avx/MatmulAvx.S | 993 ++++ .../assembly/avx512/ConvDwFp32RowAVX512.S | 499 ++ .../assembly/fp16/CalculateMinMaxFp16Count8.S | 56 + .../nnacl_c/assembly/fp16/ConvDwFp16Border.S | 68 + .../nnacl_c/assembly/fp16/ConvDwFp16Center.S | 312 ++ .../cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S | 129 + .../assembly/fp16/DeconvDwFp16Border.S | 51 + .../assembly/fp16/DeconvDwFp16Center.S | 75 + .../assembly/fp16/DynamicGatherArm64ForFp16.S | 54 + .../nnacl_c/assembly/fp16/Float16ToFloat32.S | 68 + .../nnacl_c/assembly/fp16/Float32ToFloat16.S | 68 + .../cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S | 191 + .../nnacl_c/assembly/fp16/Matmul12X16Fp16.S | 1703 +++++++ .../assembly/fp16/MatmulBaseFp16Neon.S | 960 ++++ .../cpu/nnacl_c/assembly/fp16/MatmulFp16.S | 892 ++++ .../cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S | 1185 +++++ .../nnacl_c/assembly/fp16/MatmulFp16OptV2.S | 2966 ++++++++++++ .../assembly/fp16/MatmulWinogradFp16.S | 217 + .../assembly/fp16/PostFuncBiasReluC4Fp16.S | 293 ++ 
.../assembly/fp16/PostFuncBiasReluC8Fp16.S | 469 ++ .../nnacl_c/assembly/fp16/TiledC4MatmulFp16.S | 273 ++ .../cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S | 181 + .../assembly/fp16/WinogradTransLeftFp16.S | 150 + .../assembly/fp16/WinogradTransRightFp16.S | 154 + .../opt/DynamicMatmulSdot4x4x16AIWI.S | 764 +++ .../opt/DynamicMatmulSdot4x4x16AIWIForFp16.S | 788 +++ .../cpu/nnacl_c/assembly/opt/MatmulDpInt8.S | 864 ++++ .../nnacl_c/assembly/opt/MatmulDpInt8Opt.S | 1098 +++++ .../nnacl_c/assembly/opt/MatmulOptR4Int8.S | 155 + .../kernel/cpu/nnacl_c/assembly_global.h | 50 + .../kernel/cpu/nnacl_c/attention_parameter.h | 47 + .../kernel/cpu/nnacl_c/base/arithmetic_base.c | 48 + .../kernel/cpu/nnacl_c/base/arithmetic_base.h | 36 + .../cpu/nnacl_c/base/batch_to_space_base.c | 95 + .../cpu/nnacl_c/base/batch_to_space_base.h | 33 + .../kernel/cpu/nnacl_c/base/broadcast_to.c | 106 + .../kernel/cpu/nnacl_c/base/broadcast_to.h | 34 + .../kernel/cpu/nnacl_c/base/cast_base.c | 199 + .../kernel/cpu/nnacl_c/base/cast_base.h | 74 + .../cpu/nnacl_c/base/cast_base_simd.h.in | 49 + .../kernel/cpu/nnacl_c/base/concat_base.c | 54 + .../kernel/cpu/nnacl_c/base/concat_base.h | 32 + .../kernel/cpu/nnacl_c/base/conv1x1_base.c | 40 + .../kernel/cpu/nnacl_c/base/conv1x1_base.h | 32 + .../cpu/nnacl_c/base/conv_common_base.c | 128 + .../cpu/nnacl_c/base/conv_common_base.h | 41 + .../kernel/cpu/nnacl_c/base/crop_base.c | 40 + .../kernel/cpu/nnacl_c/base/crop_base.h | 35 + .../cpu/nnacl_c/base/depth_to_space_base.c | 72 + .../cpu/nnacl_c/base/depth_to_space_base.h | 31 + .../kernel/cpu/nnacl_c/base/fill_base.c | 59 + .../kernel/cpu/nnacl_c/base/fill_base.h | 33 + .../cpu/nnacl_c/base/fill_base_simd.h.in | 45 + .../cpu/nnacl_c/base/format_transpose.c | 81 + .../cpu/nnacl_c/base/format_transpose.h | 30 + .../kernel/cpu/nnacl_c/base/gather_base.c | 44 + .../kernel/cpu/nnacl_c/base/gather_base.h | 32 + .../kernel/cpu/nnacl_c/base/gather_d_base.c | 163 + .../kernel/cpu/nnacl_c/base/gather_d_base.h | 55 + .../base/minimal_filtering_generator.c | 342 ++ .../base/minimal_filtering_generator.h | 58 + .../cpu/nnacl_c/base/scatter_nd_binary.c | 111 + .../cpu/nnacl_c/base/scatter_nd_binary.h | 37 + .../nnacl_c/base/scatter_nd_binary_simd.h.in | 59 + .../cpu/nnacl_c/base/sequence_unstack_base.h | 32 + .../kernel/cpu/nnacl_c/base/slice_base.c | 173 + .../kernel/cpu/nnacl_c/base/slice_base.h | 36 + .../cpu/nnacl_c/base/space_to_depth_base.c | 54 + .../cpu/nnacl_c/base/space_to_depth_base.h | 31 + .../kernel/cpu/nnacl_c/base/split_base.c | 57 + .../kernel/cpu/nnacl_c/base/split_base.h | 32 + .../nnacl_c/base/split_with_over_lap_base.c | 38 + .../nnacl_c/base/split_with_over_lap_base.h | 33 + .../kernel/cpu/nnacl_c/base/stack_base.c | 26 + .../kernel/cpu/nnacl_c/base/stack_base.h | 30 + .../kernel/cpu/nnacl_c/base/tile_base.c | 68 + .../kernel/cpu/nnacl_c/base/tile_base.h | 32 + .../kernel/cpu/nnacl_c/base/transpose_base.c | 274 ++ .../kernel/cpu/nnacl_c/base/transpose_base.h | 69 + .../nnacl_c/base/unsorted_segment_sum_base.c | 45 + .../nnacl_c/base/unsorted_segment_sum_base.h | 38 + .../kernel/cpu/nnacl_c/base/unstack_base.c | 33 + .../kernel/cpu/nnacl_c/base/unstack_base.h | 32 + .../cpu/nnacl_c/batch_to_space_parameter.h | 30 + .../kernel/cpu/nnacl_c/batchnorm_parameter.h | 29 + .../cpu/nnacl_c/broadcast_to_parameter.h | 34 + .../kernel/cpu/nnacl_c/call_parameter.h | 28 + .../kernel/cpu/nnacl_c/clip_parameter.h | 29 + .../litert/kernel/cpu/nnacl_c/common_func.c | 35 + .../litert/kernel/cpu/nnacl_c/common_func.h | 61 + 
.../kernel/cpu/nnacl_c/concat_parameter.h | 29 + .../cpu/nnacl_c/constant_of_shape_parameter.h | 32 + .../kernel/cpu/nnacl_c/conv3d_parameter.h | 26 + .../kernel/cpu/nnacl_c/conv_parameter.h | 169 + .../kernel/cpu/nnacl_c/crop_parameter.h | 30 + .../kernel/cpu/nnacl_c/cumsum_parameter.h | 29 + .../kernel/cpu/nnacl_c/custom_gru_parameter.h | 31 + .../cpu/nnacl_c/custom_is_inf_parameter.h | 26 + .../nnacl_c/custom_masked_fill_parameter.h | 26 + .../kernel/cpu/nnacl_c/custom_parameter.h | 30 + .../cpu/nnacl_c/depth_to_space_parameter.h | 26 + .../detection_post_process_parameter.h | 48 + .../cpu/nnacl_c/dynamic_quant_parameter.h | 29 + .../src/litert/kernel/cpu/nnacl_c/errorcode.c | 46 + .../src/litert/kernel/cpu/nnacl_c/errorcode.h | 208 + .../litert/kernel/cpu/nnacl_c/exp_parameter.h | 28 + ...nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c | 533 ++ ...nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c | 781 +++ ...nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c | 573 +++ ...nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c | 844 ++++ ...nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c | 614 +++ ...nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c | 908 ++++ .../nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c | 158 + .../nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c | 198 + .../nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c | 238 + .../nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c | 278 ++ .../nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c | 318 ++ .../nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c | 358 ++ .../nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c | 198 + .../nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c | 261 + .../nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c | 324 ++ .../nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c | 387 ++ .../nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c | 450 ++ .../nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c | 513 ++ .../nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c | 238 + .../nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c | 324 ++ .../nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c | 410 ++ .../nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c | 496 ++ .../nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c | 583 +++ .../nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c | 669 +++ .../nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c | 283 ++ .../nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c | 392 ++ .../nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c | 501 ++ .../nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c | 611 +++ .../nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c | 720 +++ .../nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c | 830 ++++ .../nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c | 323 ++ .../nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c | 455 ++ .../nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c | 588 +++ .../nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c | 720 +++ .../nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c | 853 ++++ .../nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c | 363 ++ .../nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c | 518 ++ .../nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c | 674 +++ .../nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c | 830 ++++ .../nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c | 408 ++ .../nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c | 587 +++ .../nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c | 765 +++ .../nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c | 448 ++ .../nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c | 650 +++ .../nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c | 852 ++++ .../nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c | 488 ++ .../nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c | 713 +++ .../nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c | 297 ++ ...acl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c | 303 ++ 
.../nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c | 321 ++ ...acl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c | 325 ++ .../nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c | 345 ++ ...acl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c | 347 ++ .../nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c | 105 + ...acl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c | 127 + .../nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c | 129 + ...acl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c | 149 + .../nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c | 153 + ...acl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c | 173 + .../nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c | 81 + ...nacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c | 105 + .../nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c | 145 + ...acl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c | 163 + .../nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c | 185 + ...acl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c | 199 + .../nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c | 225 + ...acl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c | 237 + .../nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c | 105 + ...nacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c | 127 + .../nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c | 185 + ...acl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c | 199 + .../nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c | 241 + ...acl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c | 249 + .../nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c | 297 ++ ...acl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c | 301 ++ .../nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c | 129 + ...nacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c | 149 + .../nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c | 225 + ...acl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c | 235 + .../nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c | 297 ++ ...acl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c | 299 ++ .../nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c | 153 + ...nacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c | 171 + .../nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c | 265 + ...acl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c | 271 ++ .../nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c | 177 + ...nacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c | 193 + .../nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c | 305 ++ ...acl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c | 307 ++ .../nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c | 201 + ...nacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c | 215 + .../nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c | 225 + ...nacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c | 237 + .../nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c | 249 + ...nacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c | 259 + .../nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c | 273 ++ ...nacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c | 281 ++ ..._gemm_avx512_10x16_mask_kernel_nhwc_fp32.c | 536 ++ ..._gemm_avx512_10x32_mask_kernel_nhwc_fp32.c | 784 +++ ..._gemm_avx512_11x16_mask_kernel_nhwc_fp32.c | 577 +++ ..._gemm_avx512_11x32_mask_kernel_nhwc_fp32.c | 847 ++++ ..._gemm_avx512_12x16_mask_kernel_nhwc_fp32.c | 617 +++ ..._gemm_avx512_12x32_mask_kernel_nhwc_fp32.c | 911 ++++ ...l_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c | 161 + ...l_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c | 201 + ...l_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c | 241 + ...l_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c | 281 ++ ...l_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c | 321 ++ ...l_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c | 361 ++ ...l_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c | 201 + ...l_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c | 264 + ...l_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c | 327 ++ ...l_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c | 390 ++ ...l_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c | 453 ++ ...l_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c | 517 ++
...l_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c | 241 + ...l_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c | 327 ++ ...l_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c | 413 ++ ...l_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c | 500 ++ ...l_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c | 586 +++ ...l_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c | 672 +++ ...l_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c | 286 ++ ...l_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c | 395 ++ ...l_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c | 505 ++ ...l_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c | 614 +++ ...l_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c | 723 +++ ...l_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c | 833 ++++ ...l_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c | 326 ++ ...l_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c | 458 ++ ...l_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c | 591 +++ ...l_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c | 723 +++ ...l_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c | 856 ++++ ...l_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c | 366 ++ ...l_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c | 522 ++ ...l_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c | 677 +++ ...l_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c | 833 ++++ ...l_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c | 410 ++ ...l_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c | 589 +++ ...l_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c | 767 +++ ...l_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c | 450 ++ ...l_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c | 652 +++ ...l_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c | 854 ++++ ...l_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c | 490 ++ ...l_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c | 715 +++ .../HPC-generator/generate_hpc.sh | 88 + .../experimental/HPC-generator/generator.py | 162 + .../gemm_avx512_mask_nhwc_asm.c.in | 263 + .../template_file/gemm_avx512_nhwc_asm.c.in | 231 + .../template_file/gemm_fma_nc8hw8.c.in | 85 + .../template_file/gemm_fma_nc8hw8_asm.c.in | 149 + .../kernel/cpu/nnacl_c/fill_parameter.h | 25 + .../kernel/cpu/nnacl_c/flatten_parameter.h | 27 + .../cpu/nnacl_c/format_transpose_parameter.h | 29 + .../kernel/cpu/nnacl_c/fp16/activation_fp16.c | 319 ++ .../kernel/cpu/nnacl_c/fp16/activation_fp16.h | 43 + .../cpu/nnacl_c/fp16/arg_min_max_fp16.c | 273 ++ .../cpu/nnacl_c/fp16/arg_min_max_fp16.h | 33 + .../kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c | 1314 +++++ .../kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h | 124 + .../cpu/nnacl_c/fp16/arithmetic_self_fp16.c | 124 + .../cpu/nnacl_c/fp16/arithmetic_self_fp16.h | 57 + .../kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c | 112 + .../kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h | 36 + .../kernel/cpu/nnacl_c/fp16/cast_fp16.h | 94 + .../cpu/nnacl_c/fp16/common_func_fp16.c | 64 + .../cpu/nnacl_c/fp16/common_func_fp16.h | 40 + .../cpu/nnacl_c/fp16/constant_of_shape_fp16.h | 38 + .../cpu/nnacl_c/fp16/conv_depthwise_fp16.c | 842 ++++ .../cpu/nnacl_c/fp16/conv_depthwise_fp16.h | 65 + .../kernel/cpu/nnacl_c/fp16/conv_fp16.c | 334 ++ .../kernel/cpu/nnacl_c/fp16/conv_fp16.h | 60 + .../kernel/cpu/nnacl_c/fp16/crop_fp16.c | 155 + .../kernel/cpu/nnacl_c/fp16/crop_fp16.h | 26 + .../kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c | 70 + .../kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h | 32 + .../kernel/cpu/nnacl_c/fp16/deconv_fp16.c | 129 + .../kernel/cpu/nnacl_c/fp16/deconv_fp16.h | 36 + .../cpu/nnacl_c/fp16/deconv_winograd_fp16.c | 519 ++ .../cpu/nnacl_c/fp16/deconv_winograd_fp16.h | 48 + .../cpu/nnacl_c/fp16/dynamic_quant_fp16.c | 42 + .../cpu/nnacl_c/fp16/dynamic_quant_fp16.h | 35 + .../litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c | 88 +
.../litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h | 64 + .../kernel/cpu/nnacl_c/fp16/fill_fp16.c | 24 + .../kernel/cpu/nnacl_c/fp16/fill_fp16.h | 34 + .../litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c | 148 + .../litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h | 30 + .../cpu/nnacl_c/fp16/instance_norm_fp16.c | 217 + .../cpu/nnacl_c/fp16/instance_norm_fp16.h | 32 + .../kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c | 110 + .../kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h | 33 + .../cpu/nnacl_c/fp16/log_softmax_fp16.c | 88 + .../cpu/nnacl_c/fp16/log_softmax_fp16.h | 35 + .../kernel/cpu/nnacl_c/fp16/lstm_fp16.c | 367 ++ .../kernel/cpu/nnacl_c/fp16/lstm_fp16.h | 54 + .../kernel/cpu/nnacl_c/fp16/matmul_fp16.c | 1204 +++++ .../kernel/cpu/nnacl_c/fp16/matmul_fp16.h | 128 + .../kernel/cpu/nnacl_c/fp16/matrix_fp16.c | 83 + .../kernel/cpu/nnacl_c/fp16/matrix_fp16.h | 36 + .../kernel/cpu/nnacl_c/fp16/one_hot_fp16.c | 50 + .../kernel/cpu/nnacl_c/fp16/one_hot_fp16.h | 34 + .../kernel/cpu/nnacl_c/fp16/pack_fp16.c | 933 ++++ .../kernel/cpu/nnacl_c/fp16/pack_fp16.h | 93 + .../litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c | 48 + .../litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h | 32 + .../kernel/cpu/nnacl_c/fp16/pooling_fp16.c | 305 ++ .../kernel/cpu/nnacl_c/fp16/pooling_fp16.h | 36 + .../kernel/cpu/nnacl_c/fp16/power_fp16.c | 117 + .../kernel/cpu/nnacl_c/fp16/power_fp16.h | 64 + .../kernel/cpu/nnacl_c/fp16/prelu_fp16.c | 146 + .../kernel/cpu/nnacl_c/fp16/prelu_fp16.h | 31 + .../cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c | 290 ++ .../cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h | 35 + .../cpu/nnacl_c/fp16/ragged_range_fp16.c | 32 + .../cpu/nnacl_c/fp16/ragged_range_fp16.h | 26 + .../kernel/cpu/nnacl_c/fp16/range_fp16.h | 27 + .../kernel/cpu/nnacl_c/fp16/reduce_fp16.c | 198 + .../kernel/cpu/nnacl_c/fp16/reduce_fp16.h | 41 + .../kernel/cpu/nnacl_c/fp16/resize_fp16.c | 380 ++ .../kernel/cpu/nnacl_c/fp16/resize_fp16.h | 56 + .../kernel/cpu/nnacl_c/fp16/scale_fp16.c | 226 + .../kernel/cpu/nnacl_c/fp16/scale_fp16.h | 38 + .../kernel/cpu/nnacl_c/fp16/softmax_fp16.c | 134 + .../kernel/cpu/nnacl_c/fp16/softmax_fp16.h | 35 + .../cpu/nnacl_c/fp16/sparse_to_dense_fp16.c | 78 + .../cpu/nnacl_c/fp16/sparse_to_dense_fp16.h | 31 + .../kernel/cpu/nnacl_c/fp16/splice_fp16.c | 30 + .../kernel/cpu/nnacl_c/fp16/splice_fp16.h | 31 + .../kernel/cpu/nnacl_c/fp16/topk_fp16.c | 70 + .../kernel/cpu/nnacl_c/fp16/topk_fp16.h | 35 + .../kernel/cpu/nnacl_c/fp16/transpose_fp16.c | 257 + .../kernel/cpu/nnacl_c/fp16/transpose_fp16.h | 35 + .../kernel/cpu/nnacl_c/fp16/unique_fp16.c | 38 + .../kernel/cpu/nnacl_c/fp16/unique_fp16.h | 29 + .../kernel/cpu/nnacl_c/fp16/utils_fp16.c | 37 + .../kernel/cpu/nnacl_c/fp16/utils_fp16.h | 25 + .../kernel/cpu/nnacl_c/fp16/where_fp16.c | 34 + .../kernel/cpu/nnacl_c/fp16/where_fp16.h | 32 + .../nnacl_c/fp16/winograd_transform_fp16.c | 360 ++ .../nnacl_c/fp16/winograd_transform_fp16.h | 57 + .../cpu/nnacl_c/fp16/winograd_utils_fp16.c | 3278 +++++++++++++ .../cpu/nnacl_c/fp16/winograd_utils_fp16.h | 163 + .../nnacl_c/fp16/winograd_utils_fp16_macro.h | 437 ++ .../nnacl_c/fp16_grad/activation_grad_fp16.c | 151 + .../nnacl_c/fp16_grad/activation_grad_fp16.h | 44 + .../cpu/nnacl_c/fp16_grad/arithmetic_grad.c | 158 + .../cpu/nnacl_c/fp16_grad/arithmetic_grad.h | 41 + .../nnacl_c/fp16_grad/arithmetic_self_grad.c | 37 + .../nnacl_c/fp16_grad/arithmetic_self_grad.h | 39 + .../kernel/cpu/nnacl_c/fp16_grad/batch_norm.c | 88 + .../kernel/cpu/nnacl_c/fp16_grad/batch_norm.h | 40 + .../fp16_grad/convolution_grad_filter.c | 361 ++ 
.../fp16_grad/convolution_grad_filter.h | 33 + .../fp16_grad/convolution_grad_input.c | 332 ++ .../fp16_grad/convolution_grad_input.h | 33 + .../cpu/nnacl_c/fp16_grad/dropout_grad.c | 24 + .../cpu/nnacl_c/fp16_grad/dropout_grad.h | 32 + .../kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c | 385 ++ .../kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h | 46 + .../cpu/nnacl_c/fp16_grad/layernorm_grad.c | 60 + .../cpu/nnacl_c/fp16_grad/layernorm_grad.h | 32 + .../cpu/nnacl_c/fp16_grad/pack_fp16_ext.c | 201 + .../cpu/nnacl_c/fp16_grad/pack_fp16_ext.h | 37 + .../cpu/nnacl_c/fp16_grad/pooling_grad.c | 192 + .../cpu/nnacl_c/fp16_grad/pooling_grad.h | 34 + .../cpu/nnacl_c/fp16_grad/resize_grad.c | 146 + .../cpu/nnacl_c/fp16_grad/resize_grad.h | 45 + .../nnacl_c/fp16_grad/strided_slice_grad.c | 67 + .../nnacl_c/fp16_grad/strided_slice_grad.h | 31 + .../nnacl_c/fp16_grad/unsorted_segment_sum.c | 34 + .../nnacl_c/fp16_grad/unsorted_segment_sum.h | 31 + .../kernel/cpu/nnacl_c/fp32/activation_fp32.c | 292 ++ .../kernel/cpu/nnacl_c/fp32/activation_fp32.h | 50 + .../nnacl_c/fp32/activation_fp32_simd.h.in | 289 ++ .../kernel/cpu/nnacl_c/fp32/adam_fp32.c | 239 + .../kernel/cpu/nnacl_c/fp32/adam_fp32.h | 49 + .../cpu/nnacl_c/fp32/adam_fp32_simd.h.in | 203 + .../litert/kernel/cpu/nnacl_c/fp32/add_fp32.c | 156 + .../litert/kernel/cpu/nnacl_c/fp32/add_fp32.h | 47 + .../cpu/nnacl_c/fp32/add_fp32_simd.h.in | 153 + .../kernel/cpu/nnacl_c/fp32/adder_fp32.c | 93 + .../kernel/cpu/nnacl_c/fp32/adder_fp32.h | 47 + .../cpu/nnacl_c/fp32/arg_min_max_fp32.c | 298 ++ .../cpu/nnacl_c/fp32/arg_min_max_fp32.h | 34 + .../nnacl_c/fp32/arithmetic_compare_fp32.c | 198 + .../nnacl_c/fp32/arithmetic_compare_fp32.h | 77 + .../kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c | 482 ++ .../kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h | 86 + .../nnacl_c/fp32/arithmetic_fp32_simd.h.in | 287 ++ .../cpu/nnacl_c/fp32/arithmetic_self_fp32.c | 230 + .../cpu/nnacl_c/fp32/arithmetic_self_fp32.h | 70 + .../fp32/arithmetic_self_fp32_simd.h.in | 152 + .../kernel/cpu/nnacl_c/fp32/attention_fp32.c | 581 +++ .../kernel/cpu/nnacl_c/fp32/attention_fp32.h | 72 + .../kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c | 129 + .../kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h | 40 + .../cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in | 60 + .../nnacl_c/fp32/bce_with_logits_loss_fp32.h | 29 + .../fp32/bce_with_logits_loss_fp32_simd.h.in | 62 + .../nnacl_c/fp32/bce_with_loigts_loss_fp32.c | 45 + .../litert/kernel/cpu/nnacl_c/fp32/bias_add.c | 123 + .../litert/kernel/cpu/nnacl_c/fp32/bias_add.h | 34 + .../cpu/nnacl_c/fp32/bias_add_simd.h.in | 57 + .../kernel/cpu/nnacl_c/fp32/cdist_fp32.c | 77 + .../kernel/cpu/nnacl_c/fp32/cdist_fp32.h | 35 + .../cpu/nnacl_c/fp32/cdist_fp32_simd.h.in | 63 + .../cpu/nnacl_c/fp32/common_func_fp32.c | 117 + .../cpu/nnacl_c/fp32/common_func_fp32.h | 106 + .../cpu/nnacl_c/fp32/constant_of_shape_fp32.h | 52 + .../cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c | 1608 ++++++ .../cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h | 40 + .../cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h | 21 + .../cpu/nnacl_c/fp32/conv_common_fp32.c | 435 ++ .../cpu/nnacl_c/fp32/conv_common_fp32.h | 60 + .../nnacl_c/fp32/conv_depthwise_avx_fp32.c | 93 + .../nnacl_c/fp32/conv_depthwise_avx_fp32.h | 37 + .../cpu/nnacl_c/fp32/conv_depthwise_fp32.c | 2074 ++++++++ .../cpu/nnacl_c/fp32/conv_depthwise_fp32.h | 148 + .../nnacl_c/fp32/conv_im2col_avx512_fp32.c | 92 + .../nnacl_c/fp32/conv_im2col_avx512_fp32.h | 38 + .../cpu/nnacl_c/fp32/conv_im2col_fp32.c | 65 + .../cpu/nnacl_c/fp32/conv_im2col_fp32.h | 33 + 
.../litert/kernel/cpu/nnacl_c/fp32/conv_sw.h | 131 + .../cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c | 99 + .../cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h | 33 + .../cpu/nnacl_c/fp32/conv_sw_avx_fp32.c | 1231 +++++ .../cpu/nnacl_c/fp32/conv_sw_avx_fp32.h | 42 + .../cpu/nnacl_c/fp32/conv_winograd_fp32.c | 265 + .../cpu/nnacl_c/fp32/conv_winograd_fp32.h | 48 + .../kernel/cpu/nnacl_c/fp32/crop_fp32.c | 94 + .../kernel/cpu/nnacl_c/fp32/crop_fp32.h | 34 + .../kernel/cpu/nnacl_c/fp32/cumsum_fp32.c | 200 + .../kernel/cpu/nnacl_c/fp32/cumsum_fp32.h | 32 + .../cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in | 114 + .../kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c | 72 + .../kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h | 32 + .../kernel/cpu/nnacl_c/fp32/deconv_fp32.c | 109 + .../kernel/cpu/nnacl_c/fp32/deconv_fp32.h | 37 + .../cpu/nnacl_c/fp32/deconv_winograd_fp32.c | 733 +++ .../cpu/nnacl_c/fp32/deconv_winograd_fp32.h | 46 + .../fp32/detection_post_process_fp32.c | 235 + .../fp32/detection_post_process_fp32.h | 60 + .../litert/kernel/cpu/nnacl_c/fp32/div_fp32.c | 136 + .../litert/kernel/cpu/nnacl_c/fp32/div_fp32.h | 43 + .../cpu/nnacl_c/fp32/div_fp32_simd.h.in | 160 + .../kernel/cpu/nnacl_c/fp32/dropout_fp32.c | 28 + .../kernel/cpu/nnacl_c/fp32/dropout_fp32.h | 28 + .../cpu/nnacl_c/fp32/dropout_fp32_simd.h.in | 39 + .../cpu/nnacl_c/fp32/embedding_lookup_fp32.c | 62 + .../cpu/nnacl_c/fp32/embedding_lookup_fp32.h | 43 + .../litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c | 62 + .../litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h | 35 + .../cpu/nnacl_c/fp32/exp_fp32_simd.h.in | 56 + .../kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c | 28 + .../kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h | 30 + .../kernel/cpu/nnacl_c/fp32/group_norm_fp32.c | 125 + .../kernel/cpu/nnacl_c/fp32/group_norm_fp32.h | 35 + .../nnacl_c/fp32/group_norm_fp32_simd.h.in | 70 + .../litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c | 154 + .../litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h | 30 + .../cpu/nnacl_c/fp32/instance_norm_fp32.c | 374 ++ .../cpu/nnacl_c/fp32/instance_norm_fp32.h | 50 + .../nnacl_c/fp32/invert_permutation_fp32.c | 31 + .../nnacl_c/fp32/invert_permutation_fp32.h | 30 + .../kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c | 78 + .../kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h | 34 + .../kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c | 93 + .../kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h | 33 + .../nnacl_c/fp32/layer_norm_fp32_simd.h.in | 61 + .../nnacl_c/fp32/local_response_norm_fp32.c | 71 + .../nnacl_c/fp32/local_response_norm_fp32.h | 26 + .../cpu/nnacl_c/fp32/log_softmax_fp32.c | 85 + .../cpu/nnacl_c/fp32/log_softmax_fp32.h | 31 + .../kernel/cpu/nnacl_c/fp32/lstm_fp32.c | 328 ++ .../kernel/cpu/nnacl_c/fp32/lstm_fp32.h | 55 + .../cpu/nnacl_c/fp32/matmul_avx512_fp32.c | 248 + .../cpu/nnacl_c/fp32/matmul_avx512_fp32.h | 198 + .../nnacl_c/fp32/matmul_avx512_mask_fp32.c | 236 + .../nnacl_c/fp32/matmul_avx512_mask_fp32.h | 209 + .../kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c | 954 ++++ .../kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h | 63 + .../kernel/cpu/nnacl_c/fp32/matmul_fp32.c | 822 ++++ .../kernel/cpu/nnacl_c/fp32/matmul_fp32.h | 99 + .../cpu/nnacl_c/fp32/matmul_fp32_simd.h.in | 148 + .../litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c | 187 + .../litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h | 45 + .../cpu/nnacl_c/fp32/mul_fp32_simd.h.in | 211 + .../kernel/cpu/nnacl_c/fp32/nllloss_fp32.c | 49 + .../kernel/cpu/nnacl_c/fp32/nllloss_fp32.h | 30 + .../nnacl_c/fp32/non_max_suppression_fp32.c | 205 + .../nnacl_c/fp32/non_max_suppression_fp32.h | 25 + .../kernel/cpu/nnacl_c/fp32/one_hot_fp32.c | 51 + 
.../kernel/cpu/nnacl_c/fp32/one_hot_fp32.h | 33 + .../online_fusion/cast_gather_reduce_fp32.c | 69 + .../online_fusion/cast_gather_reduce_fp32.h | 37 + .../cast_gather_reduce_fp32_simd.h.in | 65 + .../fp32/online_fusion/reduce_concat_fp32.c | 124 + .../fp32/online_fusion/reduce_concat_fp32.h | 34 + .../reduce_concat_fp32_simd.h.in | 115 + .../online_fusion/split_reduce_concat_fp32.c | 42 + .../online_fusion/split_reduce_concat_fp32.h | 33 + .../kernel/cpu/nnacl_c/fp32/pack_fp32.c | 2078 ++++++++ .../kernel/cpu/nnacl_c/fp32/pack_fp32.h | 130 + .../kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c | 292 ++ .../kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h | 38 + .../litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c | 83 + .../litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h | 40 + .../kernel/cpu/nnacl_c/fp32/pooling_fp32.c | 786 +++ .../kernel/cpu/nnacl_c/fp32/pooling_fp32.h | 48 + .../cpu/nnacl_c/fp32/pooling_fp32_simd.h.in | 116 + .../kernel/cpu/nnacl_c/fp32/power_fp32.c | 70 + .../kernel/cpu/nnacl_c/fp32/power_fp32.h | 41 + .../cpu/nnacl_c/fp32/power_fp32_simd.h.in | 94 + .../kernel/cpu/nnacl_c/fp32/prelu_fp32.c | 164 + .../kernel/cpu/nnacl_c/fp32/prelu_fp32.h | 31 + .../kernel/cpu/nnacl_c/fp32/prior_box_fp32.h | 41 + .../cpu/nnacl_c/fp32/ragged_range_fp32.c | 52 + .../cpu/nnacl_c/fp32/ragged_range_fp32.h | 26 + .../kernel/cpu/nnacl_c/fp32/range_fp32.h | 34 + .../kernel/cpu/nnacl_c/fp32/rank_fp32.h | 32 + .../kernel/cpu/nnacl_c/fp32/reduce_fp32.c | 359 ++ .../kernel/cpu/nnacl_c/fp32/reduce_fp32.h | 69 + .../cpu/nnacl_c/fp32/reduce_fp32_simd.h.in | 220 + .../kernel/cpu/nnacl_c/fp32/resize_fp32.c | 598 +++ .../kernel/cpu/nnacl_c/fp32/resize_fp32.h | 74 + .../kernel/cpu/nnacl_c/fp32/reverse_fp32.c | 28 + .../kernel/cpu/nnacl_c/fp32/reverse_fp32.h | 31 + .../cpu/nnacl_c/fp32/reverse_sequence_fp32.c | 40 + .../cpu/nnacl_c/fp32/reverse_sequence_fp32.h | 33 + .../kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c | 147 + .../kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h | 34 + .../cpu/nnacl_c/fp32/roi_pooling_fp32.c | 97 + .../cpu/nnacl_c/fp32/roi_pooling_fp32.h | 54 + .../kernel/cpu/nnacl_c/fp32/scale_fp32.c | 304 ++ .../kernel/cpu/nnacl_c/fp32/scale_fp32.h | 35 + .../kernel/cpu/nnacl_c/fp32/softmax_fp32.c | 125 + .../kernel/cpu/nnacl_c/fp32/softmax_fp32.h | 33 + .../cpu/nnacl_c/fp32/softmax_fp32_simd.h.in | 80 + .../nnacl_c/fp32/softmax_grad_fusion_fp32.c | 36 + .../nnacl_c/fp32/softmax_grad_fusion_fp32.h | 31 + .../fp32/softmax_grad_fusion_fp32_simd.h.in | 55 + .../cpu/nnacl_c/fp32/space_to_batch_fp32.c | 63 + .../cpu/nnacl_c/fp32/space_to_batch_fp32.h | 50 + .../cpu/nnacl_c/fp32/sparse_to_dense_fp32.c | 77 + .../cpu/nnacl_c/fp32/sparse_to_dense_fp32.h | 31 + .../kernel/cpu/nnacl_c/fp32/splice_fp32.c | 31 + .../kernel/cpu/nnacl_c/fp32/splice_fp32.h | 26 + .../cpu/nnacl_c/fp32/squared_difference.c | 32 + .../cpu/nnacl_c/fp32/squared_difference.h | 36 + .../cpu/nnacl_c/fp32/strided_slice_fp32.c | 125 + .../cpu/nnacl_c/fp32/strided_slice_fp32.h | 35 + .../litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c | 150 + .../litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h | 44 + .../cpu/nnacl_c/fp32/sub_fp32_simd.h.in | 199 + .../kernel/cpu/nnacl_c/fp32/topk_fp32.c | 106 + .../kernel/cpu/nnacl_c/fp32/topk_fp32.h | 50 + .../kernel/cpu/nnacl_c/fp32/transpose_fp32.c | 248 + .../kernel/cpu/nnacl_c/fp32/transpose_fp32.h | 35 + .../cpu/nnacl_c/fp32/transpose_server_fp32.c | 239 + .../cpu/nnacl_c/fp32/transpose_server_fp32.h | 40 + .../kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c | 179 + .../kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h | 42 + 
.../kernel/cpu/nnacl_c/fp32/unique_fp32.c | 67 + .../kernel/cpu/nnacl_c/fp32/unique_fp32.h | 36 + .../kernel/cpu/nnacl_c/fp32/where_fp32.c | 35 + .../kernel/cpu/nnacl_c/fp32/where_fp32.h | 32 + .../kernel/cpu/nnacl_c/fp32/winograd_avx.c | 2233 +++++++++ .../kernel/cpu/nnacl_c/fp32/winograd_avx.h | 299 ++ .../cpu/nnacl_c/fp32/winograd_transform.c | 281 ++ .../cpu/nnacl_c/fp32/winograd_transform.h | 51 + .../kernel/cpu/nnacl_c/fp32/winograd_utils.c | 4289 +++++++++++++++++ .../kernel/cpu/nnacl_c/fp32/winograd_utils.h | 373 ++ .../nnacl_c/fp32_grad/activation_grad_fp32.c | 161 + .../nnacl_c/fp32_grad/activation_grad_fp32.h | 49 + .../fp32_grad/activation_grad_simd.h.in | 50 + .../fp32_grad/apply_proximal_adagrad_fp32.c | 48 + .../fp32_grad/apply_proximal_adagrad_fp32.h | 31 + .../apply_proximal_adagrad_fp32_simd.h.in | 68 + .../apply_proximal_gradient_descent_fp32.c | 44 + .../apply_proximal_gradient_descent_fp32.h | 31 + ...y_proximal_gradient_descent_fp32_simd.h.in | 64 + .../cpu/nnacl_c/fp32_grad/arithmetic_grad.c | 154 + .../cpu/nnacl_c/fp32_grad/arithmetic_grad.h | 38 + .../cpu/nnacl_c/fp32_grad/batch_norm_grad.c | 100 + .../cpu/nnacl_c/fp32_grad/batch_norm_grad.h | 37 + .../nnacl_c/fp32_grad/batch_norm_parameter.h | 28 + .../nnacl_c/fp32_grad/binary_cross_entropy.c | 75 + .../nnacl_c/fp32_grad/binary_cross_entropy.h | 36 + .../fp32_grad/binary_cross_entropy_grad.c | 56 + .../fp32_grad/binary_cross_entropy_grad.h | 36 + .../fp32_grad/convolution_grad_filter.c | 380 ++ .../fp32_grad/convolution_grad_filter.h | 32 + .../fp32_grad/convolution_grad_input.c | 100 + .../fp32_grad/convolution_grad_input.h | 32 + .../cpu/nnacl_c/fp32_grad/dropout_grad.c | 23 + .../cpu/nnacl_c/fp32_grad/dropout_grad.h | 31 + .../cpu/nnacl_c/fp32_grad/dropout_parameter.h | 27 + .../kernel/cpu/nnacl_c/fp32_grad/gemm.c | 855 ++++ .../kernel/cpu/nnacl_c/fp32_grad/gemm.h | 45 + .../cpu/nnacl_c/fp32_grad/layernorm_grad.c | 64 + .../cpu/nnacl_c/fp32_grad/layernorm_grad.h | 29 + .../fp32_grad/layernormgrad_parameter.h | 27 + .../cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c | 237 + .../cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h | 70 + .../cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c | 147 + .../cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h | 36 + .../cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c | 42 + .../cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h | 31 + .../kernel/cpu/nnacl_c/fp32_grad/optimizer.h | 40 + .../kernel/cpu/nnacl_c/fp32_grad/pack_ext.c | 301 ++ .../kernel/cpu/nnacl_c/fp32_grad/pack_ext.h | 39 + .../cpu/nnacl_c/fp32_grad/pooling_grad.c | 190 + .../cpu/nnacl_c/fp32_grad/pooling_grad.h | 34 + .../cpu/nnacl_c/fp32_grad/reduce_grad.c | 89 + .../cpu/nnacl_c/fp32_grad/reduce_grad.h | 30 + .../cpu/nnacl_c/fp32_grad/resize_grad.c | 149 + .../cpu/nnacl_c/fp32_grad/resize_grad.h | 33 + .../nnacl_c/fp32_grad/resize_grad_parameter.h | 34 + .../cpu/nnacl_c/fp32_grad/smooth_l1_loss.h | 27 + .../softmax_cross_entropy_with_logits.c | 43 + .../softmax_cross_entropy_with_logits.h | 33 + .../softmax_crossentropy_parameter.h | 36 + .../cpu/nnacl_c/fp32_grad/softmax_grad.c | 59 + .../cpu/nnacl_c/fp32_grad/softmax_grad.h | 33 + .../nnacl_c/fp32_grad/softmax_grad_utils.c | 102 + .../nnacl_c/fp32_grad/softmax_grad_utils.h | 33 + .../nnacl_c/fp32_grad/strided_slice_grad.c | 68 + .../nnacl_c/fp32_grad/strided_slice_grad.h | 30 + .../kernel/cpu/nnacl_c/fp32_grad/utils.h | 72 + .../fp32_sparse/matmul_sparse_x1_fp32.c | 54 + .../fp32_sparse/matmul_sparse_x1_fp32.h | 41 + .../kernel/cpu/nnacl_c/gather_nd_parameter.h | 26 + 
.../kernel/cpu/nnacl_c/gather_parameter.h | 28 + .../kernel/cpu/nnacl_c/gelu_parameter.h | 28 + .../litert/kernel/cpu/nnacl_c/glu_parameter.h | 26 + .../cpu/nnacl_c/grid_sampler_parameter.h | 28 + .../kernel/cpu/nnacl_c/group_norm_parameter.h | 41 + .../litert/kernel/cpu/nnacl_c/gru_parameter.h | 38 + .../cpu/nnacl_c/infer/activation_grad_infer.c | 45 + .../cpu/nnacl_c/infer/activation_grad_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/adam_infer.c | 44 + .../kernel/cpu/nnacl_c/infer/adam_infer.h | 31 + .../nnacl_c/infer/adam_weight_decay_infer.c | 56 + .../nnacl_c/infer/adam_weight_decay_infer.h | 32 + .../cpu/nnacl_c/infer/add_sub_grad_infer.c | 62 + .../cpu/nnacl_c/infer/add_sub_grad_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/addn_infer.c | 86 + .../kernel/cpu/nnacl_c/infer/addn_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/affine_infer.c | 122 + .../kernel/cpu/nnacl_c/infer/affine_infer.h | 32 + .../cpu/nnacl_c/infer/all_gather_infer.c | 54 + .../cpu/nnacl_c/infer/all_gather_infer.h | 33 + .../cpu/nnacl_c/infer/apply_momentum_infer.c | 47 + .../cpu/nnacl_c/infer/apply_momentum_infer.h | 31 + .../cpu/nnacl_c/infer/argmin_max_infer.c | 83 + .../cpu/nnacl_c/infer/argmin_max_infer.h | 32 + .../nnacl_c/infer/arithmetic_compare_infer.c | 36 + .../nnacl_c/infer/arithmetic_compare_infer.h | 31 + .../cpu/nnacl_c/infer/arithmetic_grad_infer.c | 107 + .../cpu/nnacl_c/infer/arithmetic_grad_infer.h | 31 + .../cpu/nnacl_c/infer/arithmetic_infer.c | 123 + .../cpu/nnacl_c/infer/arithmetic_infer.h | 32 + .../cpu/nnacl_c/infer/assert_op_infer.c | 25 + .../cpu/nnacl_c/infer/assert_op_infer.h | 31 + .../cpu/nnacl_c/infer/assign_add_infer.c | 38 + .../cpu/nnacl_c/infer/assign_add_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/assign_infer.c | 41 + .../kernel/cpu/nnacl_c/infer/assign_infer.h | 31 + .../cpu/nnacl_c/infer/attention_infer.c | 74 + .../cpu/nnacl_c/infer/attention_infer.h | 31 + .../nnacl_c/infer/audio_spectrogram_infer.c | 75 + .../nnacl_c/infer/audio_spectrogram_infer.h | 37 + .../cpu/nnacl_c/infer/batch_to_space_infer.c | 144 + .../cpu/nnacl_c/infer/batch_to_space_infer.h | 32 + .../cpu/nnacl_c/infer/bias_grad_infer.c | 41 + .../cpu/nnacl_c/infer/bias_grad_infer.h | 31 + .../infer/binary_cross_entropy_infer.c | 40 + .../infer/binary_cross_entropy_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/bn_grad_infer.c | 42 + .../kernel/cpu/nnacl_c/infer/bn_grad_infer.h | 31 + .../cpu/nnacl_c/infer/broadcast_to_infer.c | 200 + .../cpu/nnacl_c/infer/broadcast_to_infer.h | 36 + .../nnacl_c/infer/cast_gather_reduce_infer.c | 77 + .../nnacl_c/infer/cast_gather_reduce_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/cast_infer.c | 52 + .../kernel/cpu/nnacl_c/infer/cast_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/common_infer.c | 338 ++ .../kernel/cpu/nnacl_c/infer/common_infer.h | 94 + .../kernel/cpu/nnacl_c/infer/concat_infer.c | 97 + .../kernel/cpu/nnacl_c/infer/concat_infer.h | 32 + .../nnacl_c/infer/constant_of_shape_infer.c | 71 + .../nnacl_c/infer/constant_of_shape_infer.h | 32 + .../infer/control/tensor_array_infer.c | 47 + .../infer/control/tensor_array_infer.h | 31 + .../infer/control/tensor_array_read_infer.c | 43 + .../infer/control/tensor_array_read_infer.h | 31 + .../infer/control/tensor_array_write_infer.c | 54 + .../infer/control/tensor_array_write_infer.h | 31 + .../control/tensorlist_fromtensor_infer.c | 81 + .../control/tensorlist_fromtensor_infer.h | 31 + .../infer/control/tensorlist_getitem_infer.c | 102 + .../infer/control/tensorlist_getitem_infer.h | 32 + .../infer/control/tensorlist_reserve_infer.c | 84 +
.../infer/control/tensorlist_reserve_infer.h | 31 + .../infer/control/tensorlist_setitem_infer.c | 129 + .../infer/control/tensorlist_setitem_infer.h | 31 + .../infer/control/tensorlist_stack_infer.c | 96 + .../infer/control/tensorlist_stack_infer.h | 31 + .../nnacl_c/infer/conv2d_grad_filter_infer.c | 61 + .../nnacl_c/infer/conv2d_grad_filter_infer.h | 32 + .../nnacl_c/infer/conv2d_grad_input_infer.c | 63 + .../nnacl_c/infer/conv2d_grad_input_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/conv2d_infer.c | 169 + .../kernel/cpu/nnacl_c/infer/conv2d_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/conv3d_infer.c | 27 + .../kernel/cpu/nnacl_c/infer/conv3d_infer.h | 32 + .../cpu/nnacl_c/infer/crop_and_resize_infer.c | 69 + .../cpu/nnacl_c/infer/crop_and_resize_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/crop_infer.c | 44 + .../kernel/cpu/nnacl_c/infer/crop_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/cumsum_infer.c | 41 + .../kernel/cpu/nnacl_c/infer/cumsum_infer.h | 31 + .../cpu/nnacl_c/infer/custom_gru_infer.c | 45 + .../cpu/nnacl_c/infer/custom_gru_infer.h | 30 + .../cpu/nnacl_c/infer/custom_is_inf_infer.c | 40 + .../cpu/nnacl_c/infer/custom_is_inf_infer.h | 31 + .../nnacl_c/infer/custom_masked_fill_infer.c | 37 + .../nnacl_c/infer/custom_masked_fill_infer.h | 31 + .../infer/custom_tensor_scatter_max_infer.c | 37 + .../infer/custom_tensor_scatter_max_infer.h | 31 + .../cpu/nnacl_c/infer/decoder_layer_infer.c | 36 + .../cpu/nnacl_c/infer/decoder_layer_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/deconv2d_infer.c | 119 + .../kernel/cpu/nnacl_c/infer/deconv2d_infer.h | 32 + .../cpu/nnacl_c/infer/depth_to_space_infer.c | 60 + .../cpu/nnacl_c/infer/depth_to_space_infer.h | 32 + .../nnacl_c/infer/depthwise_conv2d_infer.c | 81 + .../nnacl_c/infer/depthwise_conv2d_infer.h | 32 + .../infer/detection_post_process_infer.c | 83 + .../infer/detection_post_process_infer.h | 32 + .../cpu/nnacl_c/infer/dropout_grad_infer.c | 40 + .../cpu/nnacl_c/infer/dropout_grad_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/dropout_infer.c | 42 + .../kernel/cpu/nnacl_c/infer/dropout_infer.h | 31 + .../cpu/nnacl_c/infer/dynamic_quant_infer.c | 42 + .../cpu/nnacl_c/infer/dynamic_quant_infer.h | 31 + .../nnacl_c/infer/embedding_lookup_infer.c | 77 + .../nnacl_c/infer/embedding_lookup_infer.h | 31 + .../cpu/nnacl_c/infer/encoder_layer_infer.c | 36 + .../cpu/nnacl_c/infer/encoder_layer_infer.h | 31 + .../cpu/nnacl_c/infer/expand_dims_infer.c | 64 + .../cpu/nnacl_c/infer/expand_dims_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/fft_imag_infer.c | 25 + .../kernel/cpu/nnacl_c/infer/fft_imag_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/fft_real_infer.c | 25 + .../kernel/cpu/nnacl_c/infer/fft_real_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/fill_infer.c | 65 + .../kernel/cpu/nnacl_c/infer/fill_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/fillv2_infer.c | 62 + .../kernel/cpu/nnacl_c/infer/fillv2_infer.h | 31 + .../cpu/nnacl_c/infer/flatten_grad_infer.c | 43 + .../cpu/nnacl_c/infer/flatten_grad_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/flatten_infer.c | 64 + .../kernel/cpu/nnacl_c/infer/flatten_infer.h | 31 + .../nnacl_c/infer/format_transpose_infer.c | 67 + .../nnacl_c/infer/format_transpose_infer.h | 31 + .../cpu/nnacl_c/infer/fse_decoder_infer.c | 35 + .../cpu/nnacl_c/infer/fse_decoder_infer.h | 31 + .../cpu/nnacl_c/infer/full_connection_infer.c | 92 + .../cpu/nnacl_c/infer/full_connection_infer.h | 32 + .../cpu/nnacl_c/infer/fused_batchnorm_infer.c | 45 + .../cpu/nnacl_c/infer/fused_batchnorm_infer.h | 31 +
.../kernel/cpu/nnacl_c/infer/gather_d_infer.c | 47 + .../kernel/cpu/nnacl_c/infer/gather_d_infer.h | 33 + .../kernel/cpu/nnacl_c/infer/gather_infer.c | 83 + .../kernel/cpu/nnacl_c/infer/gather_infer.h | 32 + .../cpu/nnacl_c/infer/gather_nd_infer.c | 59 + .../cpu/nnacl_c/infer/gather_nd_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/glu_infer.c | 48 + .../kernel/cpu/nnacl_c/infer/glu_infer.h | 32 + .../cpu/nnacl_c/infer/grid_sampler_infer.c | 47 + .../cpu/nnacl_c/infer/grid_sampler_infer.h | 32 + .../infer/group_conv2d_grad_input_infer.c | 45 + .../infer/group_conv2d_grad_input_infer.h | 32 + .../cpu/nnacl_c/infer/group_norm_infer.c | 37 + .../cpu/nnacl_c/infer/group_norm_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/gru_infer.c | 92 + .../kernel/cpu/nnacl_c/infer/gru_infer.h | 32 + .../litert/kernel/cpu/nnacl_c/infer/infer.h | 33 + .../kernel/cpu/nnacl_c/infer/infer_register.c | 450 ++ .../kernel/cpu/nnacl_c/infer/infer_register.h | 39 + .../cpu/nnacl_c/infer/instance_norm_infer.c | 46 + .../cpu/nnacl_c/infer/instance_norm_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/invalid_infer.c | 32 + .../kernel/cpu/nnacl_c/infer/invalid_infer.h | 31 + .../nnacl_c/infer/invert_permutation_infer.c | 43 + .../nnacl_c/infer/invert_permutation_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/isfinite_infer.c | 42 + .../kernel/cpu/nnacl_c/infer/isfinite_infer.h | 31 + .../cpu/nnacl_c/infer/layer_norm_grad_infer.c | 57 + .../cpu/nnacl_c/infer/layer_norm_grad_infer.h | 31 + .../cpu/nnacl_c/infer/layer_norm_infer.c | 68 + .../cpu/nnacl_c/infer/layer_norm_infer.h | 32 + .../cpu/nnacl_c/infer/lin_space_infer.c | 48 + .../cpu/nnacl_c/infer/lin_space_infer.h | 31 + .../cpu/nnacl_c/infer/log_softmax_infer.c | 51 + .../cpu/nnacl_c/infer/log_softmax_infer.h | 32 + .../cpu/nnacl_c/infer/lstm_grad_data_infer.c | 60 + .../cpu/nnacl_c/infer/lstm_grad_data_infer.h | 32 + .../cpu/nnacl_c/infer/lstm_grad_infer.c | 54 + .../cpu/nnacl_c/infer/lstm_grad_infer.h | 32 + .../nnacl_c/infer/lstm_grad_weight_infer.c | 61 + .../nnacl_c/infer/lstm_grad_weight_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/lstm_infer.c | 161 + .../kernel/cpu/nnacl_c/infer/lstm_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/matmul_infer.c | 148 + .../kernel/cpu/nnacl_c/infer/matmul_infer.h | 32 + .../cpu/nnacl_c/infer/max_min_grad_infer.c | 65 + .../cpu/nnacl_c/infer/max_min_grad_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/mfcc_infer.c | 48 + .../kernel/cpu/nnacl_c/infer/mfcc_infer.h | 36 + .../cpu/nnacl_c/infer/nllloss_grad_infer.c | 54 + .../cpu/nnacl_c/infer/nllloss_grad_infer.h | 33 + .../kernel/cpu/nnacl_c/infer/nllloss_infer.c | 52 + .../kernel/cpu/nnacl_c/infer/nllloss_infer.h | 33 + .../nnacl_c/infer/non_max_suppression_infer.c | 34 + .../nnacl_c/infer/non_max_suppression_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/one_hot_infer.c | 60 + .../kernel/cpu/nnacl_c/infer/one_hot_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/pad_infer.c | 64 + .../kernel/cpu/nnacl_c/infer/pad_infer.h | 32 + .../cpu/nnacl_c/infer/pooling_grad_infer.c | 74 + .../cpu/nnacl_c/infer/pooling_grad_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/pooling_infer.c | 107 + .../kernel/cpu/nnacl_c/infer/pooling_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/power_infer.c | 56 + .../kernel/cpu/nnacl_c/infer/power_infer.h | 32 + .../cpu/nnacl_c/infer/prior_box_infer.c | 87 + .../cpu/nnacl_c/infer/prior_box_infer.h | 32 + .../nnacl_c/infer/quant_dtype_cast_infer.c | 41 + .../nnacl_c/infer/quant_dtype_cast_infer.h | 37 + .../cpu/nnacl_c/infer/ragged_range_infer.c | 129 + 
.../cpu/nnacl_c/infer/ragged_range_infer.h | 32 + .../cpu/nnacl_c/infer/random_normal_infer.c | 38 + .../cpu/nnacl_c/infer/random_normal_infer.h | 31 + .../infer/random_standard_normal_infer.c | 52 + .../infer/random_standard_normal_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/range_infer.c | 91 + .../kernel/cpu/nnacl_c/infer/range_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/rank_infer.c | 38 + .../kernel/cpu/nnacl_c/infer/rank_infer.h | 31 + .../cpu/nnacl_c/infer/reduce_concat_infer.c | 95 + .../cpu/nnacl_c/infer/reduce_concat_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/reduce_infer.c | 140 + .../kernel/cpu/nnacl_c/infer/reduce_infer.h | 32 + .../cpu/nnacl_c/infer/reduce_scatter_infer.c | 56 + .../cpu/nnacl_c/infer/reduce_scatter_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/reshape_infer.c | 221 + .../kernel/cpu/nnacl_c/infer/reshape_infer.h | 32 + .../cpu/nnacl_c/infer/resize_grad_infer.c | 59 + .../cpu/nnacl_c/infer/resize_grad_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/resize_infer.c | 129 + .../kernel/cpu/nnacl_c/infer/resize_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/rfft_infer.c | 47 + .../kernel/cpu/nnacl_c/infer/rfft_infer.h | 36 + .../cpu/nnacl_c/infer/roi_pooling_infer.c | 51 + .../cpu/nnacl_c/infer/roi_pooling_infer.h | 32 + .../cpu/nnacl_c/infer/scatter_nd_infer.c | 45 + .../cpu/nnacl_c/infer/scatter_nd_infer.h | 31 + .../nnacl_c/infer/scatter_nd_update_infer.c | 59 + .../nnacl_c/infer/scatter_nd_update_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/select_infer.c | 62 + .../kernel/cpu/nnacl_c/infer/select_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/sgd_infer.c | 43 + .../kernel/cpu/nnacl_c/infer/sgd_infer.h | 31 + .../cpu/nnacl_c/infer/shape_fusion_infer.c | 97 + .../cpu/nnacl_c/infer/shape_fusion_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/shape_infer.c | 40 + .../kernel/cpu/nnacl_c/infer/shape_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/size_infer.c | 41 + .../kernel/cpu/nnacl_c/infer/size_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/slice_infer.c | 126 + .../kernel/cpu/nnacl_c/infer/slice_infer.h | 32 + .../infer/softmax_cross_entropy_infer.c | 43 + .../infer/softmax_cross_entropy_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/softmax_infer.c | 49 + .../kernel/cpu/nnacl_c/infer/softmax_infer.h | 32 + .../cpu/nnacl_c/infer/space_to_batch_infer.c | 64 + .../cpu/nnacl_c/infer/space_to_batch_infer.h | 32 + .../nnacl_c/infer/space_to_batch_nd_infer.c | 143 + .../nnacl_c/infer/space_to_batch_nd_infer.h | 32 + .../cpu/nnacl_c/infer/space_to_depth_infer.c | 61 + .../cpu/nnacl_c/infer/space_to_depth_infer.h | 32 + .../infer/sparse_fill_empty_rows_infer.c | 49 + .../infer/sparse_fill_empty_rows_infer.h | 31 + .../cpu/nnacl_c/infer/sparse_reshape_infer.c | 53 + .../cpu/nnacl_c/infer/sparse_reshape_infer.h | 31 + .../nnacl_c/infer/sparse_segment_sum_infer.c | 37 + .../nnacl_c/infer/sparse_segment_sum_infer.h | 31 + ..._softmax_cross_entropy_with_logits_infer.c | 45 + ..._softmax_cross_entropy_with_logits_infer.h | 31 + .../cpu/nnacl_c/infer/sparse_to_dense_infer.c | 51 + .../cpu/nnacl_c/infer/sparse_to_dense_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/splice_infer.c | 56 + .../kernel/cpu/nnacl_c/infer/splice_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/split_infer.c | 120 + .../kernel/cpu/nnacl_c/infer/split_infer.h | 32 + .../nnacl_c/infer/split_reduce_concat_infer.c | 45 + .../nnacl_c/infer/split_reduce_concat_infer.h | 31 + .../nnacl_c/infer/split_with_over_lap_infer.c | 84 + .../nnacl_c/infer/split_with_over_lap_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/squeeze_infer.c | 81 + 
.../kernel/cpu/nnacl_c/infer/squeeze_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/stack_infer.c | 70 + .../kernel/cpu/nnacl_c/infer/stack_infer.h | 32 + .../nnacl_c/infer/strided_slice_grad_infer.c | 163 + .../nnacl_c/infer/strided_slice_grad_infer.h | 32 + .../cpu/nnacl_c/infer/strided_slice_infer.c | 483 ++ .../cpu/nnacl_c/infer/strided_slice_infer.h | 32 + .../string/custom_extract_features_infer.c | 49 + .../string/custom_extract_features_infer.h | 31 + .../infer/string/custom_normalize_infer.c | 48 + .../infer/string/custom_normalize_infer.h | 32 + .../infer/string/custom_predict_infer.c | 43 + .../infer/string/custom_predict_infer.h | 36 + .../infer/string/hashtable_lookup_infer.c | 50 + .../infer/string/hashtable_lookup_infer.h | 31 + .../infer/string/lsh_projection_infer.c | 53 + .../infer/string/lsh_projection_infer.h | 32 + .../nnacl_c/infer/string/skip_gram_infer.c | 37 + .../nnacl_c/infer/string/skip_gram_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/tile_infer.c | 111 + .../kernel/cpu/nnacl_c/infer/tile_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/topk_infer.c | 66 + .../kernel/cpu/nnacl_c/infer/topk_infer.h | 32 + .../cpu/nnacl_c/infer/transpose_infer.c | 137 + .../cpu/nnacl_c/infer/transpose_infer.h | 32 + .../cpu/nnacl_c/infer/triu_tril_infer.c | 42 + .../cpu/nnacl_c/infer/triu_tril_infer.h | 32 + .../cpu/nnacl_c/infer/uniform_real_infer.c | 49 + .../cpu/nnacl_c/infer/uniform_real_infer.h | 31 + .../kernel/cpu/nnacl_c/infer/unique_infer.c | 42 + .../kernel/cpu/nnacl_c/infer/unique_infer.h | 31 + .../infer/unsorted_segment_sum_infer.c | 49 + .../infer/unsorted_segment_sum_infer.h | 36 + .../cpu/nnacl_c/infer/unsqueeze_infer.c | 79 + .../cpu/nnacl_c/infer/unsqueeze_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/unstack_infer.c | 59 + .../kernel/cpu/nnacl_c/infer/unstack_infer.h | 32 + .../kernel/cpu/nnacl_c/infer/where_infer.c | 91 + .../kernel/cpu/nnacl_c/infer/where_infer.h | 31 + .../cpu/nnacl_c/instance_norm_parameter.h | 32 + .../litert/kernel/cpu/nnacl_c/int8/add_int8.c | 531 ++ .../litert/kernel/cpu/nnacl_c/int8/add_int8.h | 70 + .../cpu/nnacl_c/int8/arg_min_max_int8.c | 237 + .../cpu/nnacl_c/int8/arg_min_max_int8.h | 41 + .../kernel/cpu/nnacl_c/int8/arithmetic_int8.c | 137 + .../kernel/cpu/nnacl_c/int8/arithmetic_int8.h | 51 + .../cpu/nnacl_c/int8/arithmetic_self_int8.c | 305 ++ .../cpu/nnacl_c/int8/arithmetic_self_int8.h | 59 + .../cpu/nnacl_c/int8/batch_to_space_int8.c | 110 + .../cpu/nnacl_c/int8/batch_to_space_int8.h | 33 + .../kernel/cpu/nnacl_c/int8/batchnorm_int8.c | 33 + .../kernel/cpu/nnacl_c/int8/batchnorm_int8.h | 34 + .../cpu/nnacl_c/int8/common_func_int8.c | 74 + .../cpu/nnacl_c/int8/common_func_int8.h | 95 + .../kernel/cpu/nnacl_c/int8/concat_int8.c | 57 + .../kernel/cpu/nnacl_c/int8/concat_int8.h | 33 + .../kernel/cpu/nnacl_c/int8/conv1x1_int8.c | 40 + .../kernel/cpu/nnacl_c/int8/conv1x1_int8.h | 46 + .../kernel/cpu/nnacl_c/int8/conv3x3_int8.c | 902 ++++ .../kernel/cpu/nnacl_c/int8/conv3x3_int8.h | 48 + .../cpu/nnacl_c/int8/conv_depthwise_int8.c | 825 ++++ .../cpu/nnacl_c/int8/conv_depthwise_int8.h | 49 + .../kernel/cpu/nnacl_c/int8/conv_int8.c | 913 ++++ .../kernel/cpu/nnacl_c/int8/conv_int8.h | 44 + .../kernel/cpu/nnacl_c/int8/crop_int8.c | 236 + .../kernel/cpu/nnacl_c/int8/crop_int8.h | 31 + .../kernel/cpu/nnacl_c/int8/deconv_int8.c | 150 + .../kernel/cpu/nnacl_c/int8/deconv_int8.h | 46 + .../cpu/nnacl_c/int8/depth_to_space_int8.c | 51 + .../cpu/nnacl_c/int8/depth_to_space_int8.h | 32 + .../litert/kernel/cpu/nnacl_c/int8/div_int8.c | 67 + 
.../litert/kernel/cpu/nnacl_c/int8/div_int8.h | 37 + .../cpu/nnacl_c/int8/dynamic_gather_int8.c | 76 + .../cpu/nnacl_c/int8/dynamic_gather_int8.h | 40 + .../cpu/nnacl_c/int8/dynamic_matmul_int8.c | 420 ++ .../cpu/nnacl_c/int8/dynamic_matmul_int8.h | 74 + .../cpu/nnacl_c/int8/dynamic_quant_int8.c | 91 + .../cpu/nnacl_c/int8/dynamic_quant_int8.h | 34 + .../kernel/cpu/nnacl_c/int8/fixed_point.c | 276 ++ .../kernel/cpu/nnacl_c/int8/fixed_point.h | 74 + .../kernel/cpu/nnacl_c/int8/gatherNd_int8.c | 34 + .../kernel/cpu/nnacl_c/int8/gatherNd_int8.h | 32 + .../kernel/cpu/nnacl_c/int8/gather_int8.c | 68 + .../kernel/cpu/nnacl_c/int8/gather_int8.h | 35 + .../kernel/cpu/nnacl_c/int8/hswish_int8.c | 53 + .../kernel/cpu/nnacl_c/int8/hswish_int8.h | 43 + .../kernel/cpu/nnacl_c/int8/l2_norm_int8.c | 41 + .../kernel/cpu/nnacl_c/int8/l2_norm_int8.h | 32 + .../kernel/cpu/nnacl_c/int8/layer_norm_int8.c | 74 + .../kernel/cpu/nnacl_c/int8/layer_norm_int8.h | 35 + .../kernel/cpu/nnacl_c/int8/leaky_relu_int8.c | 52 + .../kernel/cpu/nnacl_c/int8/leaky_relu_int8.h | 31 + .../kernel/cpu/nnacl_c/int8/matmul_int8.c | 839 ++++ .../kernel/cpu/nnacl_c/int8/matmul_int8.h | 93 + .../litert/kernel/cpu/nnacl_c/int8/mul_int8.c | 238 + .../litert/kernel/cpu/nnacl_c/int8/mul_int8.h | 39 + .../kernel/cpu/nnacl_c/int8/pack_int8.c | 452 ++ .../kernel/cpu/nnacl_c/int8/pack_int8.h | 56 + .../litert/kernel/cpu/nnacl_c/int8/pad_int8.c | 75 + .../litert/kernel/cpu/nnacl_c/int8/pad_int8.h | 35 + .../kernel/cpu/nnacl_c/int8/pooling_int8.c | 516 ++ .../kernel/cpu/nnacl_c/int8/pooling_int8.h | 50 + .../kernel/cpu/nnacl_c/int8/power_int8.c | 48 + .../kernel/cpu/nnacl_c/int8/power_int8.h | 33 + .../cpu/nnacl_c/int8/quant_dtype_cast_int8.c | 437 ++ .../cpu/nnacl_c/int8/quant_dtype_cast_int8.h | 56 + .../litert/kernel/cpu/nnacl_c/int8/quantize.c | 161 + .../litert/kernel/cpu/nnacl_c/int8/quantize.h | 222 + .../kernel/cpu/nnacl_c/int8/reduce_int8.c | 597 +++ .../kernel/cpu/nnacl_c/int8/reduce_int8.h | 70 + .../kernel/cpu/nnacl_c/int8/relux_int8.c | 30 + .../kernel/cpu/nnacl_c/int8/relux_int8.h | 43 + .../kernel/cpu/nnacl_c/int8/reshape_int8.c | 40 + .../kernel/cpu/nnacl_c/int8/reshape_int8.h | 32 + .../kernel/cpu/nnacl_c/int8/resize_int8.c | 233 + .../kernel/cpu/nnacl_c/int8/resize_int8.h | 50 + .../kernel/cpu/nnacl_c/int8/scale_int8.c | 164 + .../kernel/cpu/nnacl_c/int8/scale_int8.h | 35 + .../kernel/cpu/nnacl_c/int8/sigmoid_int8.c | 26 + .../kernel/cpu/nnacl_c/int8/sigmoid_int8.h | 32 + .../kernel/cpu/nnacl_c/int8/slice_int8.c | 97 + .../kernel/cpu/nnacl_c/int8/slice_int8.h | 35 + .../kernel/cpu/nnacl_c/int8/softmax_int8.c | 68 + .../kernel/cpu/nnacl_c/int8/softmax_int8.h | 35 + .../cpu/nnacl_c/int8/space_to_batch_int8.c | 88 + .../cpu/nnacl_c/int8/space_to_batch_int8.h | 32 + .../kernel/cpu/nnacl_c/int8/split_int8.c | 75 + .../kernel/cpu/nnacl_c/int8/split_int8.h | 33 + .../kernel/cpu/nnacl_c/int8/squeeze_int8.c | 39 + .../kernel/cpu/nnacl_c/int8/squeeze_int8.h | 32 + .../litert/kernel/cpu/nnacl_c/int8/sub_int8.c | 105 + .../litert/kernel/cpu/nnacl_c/int8/sub_int8.h | 32 + .../kernel/cpu/nnacl_c/int8/tanh_int8.c | 30 + .../kernel/cpu/nnacl_c/int8/tanh_int8.h | 43 + .../kernel/cpu/nnacl_c/int8/topk_int8.c | 57 + .../kernel/cpu/nnacl_c/int8/topk_int8.h | 36 + .../kernel/cpu/nnacl_c/int8/transpose_int8.c | 257 + .../kernel/cpu/nnacl_c/int8/transpose_int8.h | 36 + .../kernel/cpu/nnacl_c/int8/unsqueeze_int8.c | 33 + .../kernel/cpu/nnacl_c/int8/unsqueeze_int8.h | 33 + .../nnacl_c/intrinsics/avx/DeconvMatMulAvx.c | 188 + 
.../intrinsics/avx/PostFuncBiasReluC8.c | 352 ++ .../intrinsics/avx/TiledC8MatMulFp32.c | 274 ++ .../avx/WinogradPostFuncBiasReluC8.c | 357 ++ .../nnacl_c/intrinsics/avx/WinogradTransAvx.c | 355 ++ .../cpu/nnacl_c/intrinsics/avx/common_utils.c | 66 + .../cpu/nnacl_c/intrinsics/avx/common_utils.h | 157 + .../intrinsics/ms_simd_avx512_instructions.h | 446 ++ .../intrinsics/ms_simd_avx_instructions.h | 440 ++ .../cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c | 141 + .../cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h | 61 + .../nnacl_c/intrinsics/ms_simd_instructions.h | 563 +++ .../intrinsics/ms_simd_instructions_fp16.h | 162 + .../intrinsics/ms_simd_neon_instructions.h | 362 ++ .../intrinsics/ms_simd_sse_instructions.h | 403 ++ .../intrinsics/sse/ConvDwFp32IndirectRow.c | 120 + .../intrinsics/sse/ConvDwFp32Row_sse.c | 86 + .../intrinsics/sse/DepthwiseFp32_Sse.c | 327 ++ .../cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c | 243 + .../intrinsics/sse/PostFuncBiasReluC8.c | 131 + .../intrinsics/sse/TiledC4MatMulFp32.c | 161 + .../sse/WinogradPostFuncBiasReluC4.c | 349 ++ .../nnacl_c/intrinsics/sse/WinogradTrans.c | 376 ++ .../cpu/nnacl_c/intrinsics/sse/sse_common.h | 390 ++ .../src/litert/kernel/cpu/nnacl_c/kernel.c | 124 + .../src/litert/kernel/cpu/nnacl_c/kernel.h | 69 + .../kernel/cpu/nnacl_c/kernel/activation.c | 194 + .../kernel/cpu/nnacl_c/kernel/activation.h | 25 + .../litert/kernel/cpu/nnacl_c/kernel/addn.c | 144 + .../litert/kernel/cpu/nnacl_c/kernel/addn.h | 35 + .../kernel/cpu/nnacl_c/kernel/arg_min_max.c | 127 + .../kernel/cpu/nnacl_c/kernel/arg_min_max.h | 63 + .../kernel/cpu/nnacl_c/kernel/arithmetic.c | 725 +++ .../kernel/cpu/nnacl_c/kernel/arithmetic.h | 97 + .../cpu/nnacl_c/kernel/arithmetic_compare.c | 166 + .../cpu/nnacl_c/kernel/arithmetic_compare.h | 26 + .../cpu/nnacl_c/kernel/arithmetic_self.c | 199 + .../cpu/nnacl_c/kernel/arithmetic_self.h | 48 + .../kernel/cpu/nnacl_c/kernel/batch_norm.c | 134 + .../kernel/cpu/nnacl_c/kernel/batch_norm.h | 38 + .../cpu/nnacl_c/kernel/batch_to_space.c | 114 + .../cpu/nnacl_c/kernel/batch_to_space.h | 33 + .../kernel/cpu/nnacl_c/kernel/biasadd.c | 131 + .../kernel/cpu/nnacl_c/kernel/biasadd.h | 25 + .../litert/kernel/cpu/nnacl_c/kernel/cast.c | 209 + .../litert/kernel/cpu/nnacl_c/kernel/cast.h | 32 + .../litert/kernel/cpu/nnacl_c/kernel/clip.c | 123 + .../litert/kernel/cpu/nnacl_c/kernel/clip.h | 34 + .../litert/kernel/cpu/nnacl_c/kernel/concat.c | 287 ++ .../litert/kernel/cpu/nnacl_c/kernel/concat.h | 52 + .../cpu/nnacl_c/kernel/convolution_1x1.c | 365 ++ .../cpu/nnacl_c/kernel/convolution_1x1.h | 42 + .../cpu/nnacl_c/kernel/convolution_base.c | 209 + .../cpu/nnacl_c/kernel/convolution_base.h | 63 + .../cpu/nnacl_c/kernel/convolution_delegate.c | 365 ++ .../cpu/nnacl_c/kernel/convolution_delegate.h | 39 + .../nnacl_c/kernel/convolution_depthwise.c | 229 + .../nnacl_c/kernel/convolution_depthwise.h | 36 + .../kernel/convolution_depthwise_3x3.c | 154 + .../kernel/convolution_depthwise_3x3.h | 37 + .../kernel/convolution_depthwise_indirect.c | 227 + .../kernel/convolution_depthwise_indirect.h | 39 + .../nnacl_c/kernel/convolution_depthwise_sw.c | 200 + .../nnacl_c/kernel/convolution_depthwise_sw.h | 36 + .../kernel/convolution_depthwise_sw_avx.c | 216 + .../kernel/convolution_depthwise_sw_avx.h | 40 + .../cpu/nnacl_c/kernel/convolution_im2col.c | 81 + .../cpu/nnacl_c/kernel/convolution_im2col.h | 28 + .../nnacl_c/kernel/convolution_im2col_arm32.c | 45 + .../nnacl_c/kernel/convolution_im2col_arm32.h | 30 + .../nnacl_c/kernel/convolution_im2col_arm64.c | 72 + 
.../nnacl_c/kernel/convolution_im2col_arm64.h | 29 + .../nnacl_c/kernel/convolution_im2col_avx.c | 151 + .../nnacl_c/kernel/convolution_im2col_avx.h | 29 + .../kernel/convolution_im2col_avx512.c | 146 + .../kernel/convolution_im2col_avx512.h | 29 + .../nnacl_c/kernel/convolution_im2col_base.c | 246 + .../nnacl_c/kernel/convolution_im2col_base.h | 52 + .../nnacl_c/kernel/convolution_im2col_sse.c | 47 + .../nnacl_c/kernel/convolution_im2col_sse.h | 29 + .../nnacl_c/kernel/convolution_slidewindow.c | 227 + .../nnacl_c/kernel/convolution_slidewindow.h | 46 + .../cpu/nnacl_c/kernel/convolution_sw_1x1.c | 152 + .../cpu/nnacl_c/kernel/convolution_sw_1x1.h | 36 + .../cpu/nnacl_c/kernel/convolution_sw_arm64.c | 59 + .../cpu/nnacl_c/kernel/convolution_sw_arm64.h | 28 + .../cpu/nnacl_c/kernel/convolution_sw_avx.c | 71 + .../cpu/nnacl_c/kernel/convolution_sw_avx.h | 28 + .../cpu/nnacl_c/kernel/convolution_winograd.c | 76 + .../cpu/nnacl_c/kernel/convolution_winograd.h | 32 + .../kernel/convolution_winograd_arm32.c | 42 + .../kernel/convolution_winograd_arm32.h | 30 + .../kernel/convolution_winograd_arm64.c | 60 + .../kernel/convolution_winograd_arm64.h | 30 + .../nnacl_c/kernel/convolution_winograd_avx.c | 43 + .../nnacl_c/kernel/convolution_winograd_avx.h | 30 + .../kernel/convolution_winograd_base.c | 320 ++ .../kernel/convolution_winograd_base.h | 65 + .../nnacl_c/kernel/convolution_winograd_sse.c | 44 + .../nnacl_c/kernel/convolution_winograd_sse.h | 30 + .../litert/kernel/cpu/nnacl_c/kernel/crop.c | 96 + .../litert/kernel/cpu/nnacl_c/kernel/crop.h | 31 + .../cpu/nnacl_c/kernel/crop_and_resize.c | 190 + .../cpu/nnacl_c/kernel/crop_and_resize.h | 41 + .../kernel/cpu/nnacl_c/kernel/deconvolution.c | 337 ++ .../kernel/cpu/nnacl_c/kernel/deconvolution.h | 39 + .../nnacl_c/kernel/deconvolution_depthwise.c | 233 + .../nnacl_c/kernel/deconvolution_depthwise.h | 34 + .../nnacl_c/kernel/deconvolution_winograd.c | 551 +++ .../nnacl_c/kernel/deconvolution_winograd.h | 52 + .../cpu/nnacl_c/kernel/default_kernel_base.c | 55 + .../cpu/nnacl_c/kernel/default_kernel_base.h | 32 + .../cpu/nnacl_c/kernel/depth_to_space.c | 80 + .../cpu/nnacl_c/kernel/depth_to_space.h | 42 + .../litert/kernel/cpu/nnacl_c/kernel/exp.c | 86 + .../litert/kernel/cpu/nnacl_c/kernel/exp.h | 33 + .../kernel/f16/arithmetic_compare_f16.c | 110 + .../kernel/f16/arithmetic_compare_f16.h | 26 + .../cpu/nnacl_c/kernel/f16/arithmetic_f16.c | 195 + .../cpu/nnacl_c/kernel/f16/arithmetic_f16.h | 42 + .../cpu/nnacl_c/kernel/f16/concat_f16.c | 132 + .../cpu/nnacl_c/kernel/f16/concat_f16.h | 25 + .../cpu/nnacl_c/kernel/f16/reduce_f16.c | 118 + .../cpu/nnacl_c/kernel/f16/reduce_f16.h | 27 + .../kernel/cpu/nnacl_c/kernel/f16/stack_f16.c | 96 + .../kernel/cpu/nnacl_c/kernel/f16/stack_f16.h | 32 + .../litert/kernel/cpu/nnacl_c/kernel/fill.c | 102 + .../litert/kernel/cpu/nnacl_c/kernel/fill.h | 36 + .../cpu/nnacl_c/kernel/fullconnection.c | 81 + .../cpu/nnacl_c/kernel/fullconnection.h | 25 + .../cpu/nnacl_c/kernel/fused_batch_norm.c | 327 ++ .../cpu/nnacl_c/kernel/fused_batch_norm.h | 37 + .../litert/kernel/cpu/nnacl_c/kernel/gather.c | 241 + .../litert/kernel/cpu/nnacl_c/kernel/gather.h | 46 + .../kernel/cpu/nnacl_c/kernel/gather_d.c | 124 + .../kernel/cpu/nnacl_c/kernel/gather_d.h | 25 + .../kernel/cpu/nnacl_c/kernel/gather_nd.c | 168 + .../kernel/cpu/nnacl_c/kernel/gather_nd.h | 35 + .../cpu/nnacl_c/kernel/group_convolution.c | 419 ++ .../cpu/nnacl_c/kernel/group_convolution.h | 49 + .../kernel/cpu/nnacl_c/kernel/group_norm.c | 122 + 
.../kernel/cpu/nnacl_c/kernel/group_norm.h | 31 + .../kernel/cpu/nnacl_c/kernel/init_exec_env.c | 51 + .../kernel/cpu/nnacl_c/kernel/init_exec_env.h | 27 + .../cpu/nnacl_c/kernel/init_vs_kernels.c | 357 ++ .../cpu/nnacl_c/kernel/init_vs_kernels.h | 20 + .../kernel/cpu/nnacl_c/kernel/layer_norm.c | 130 + .../kernel/cpu/nnacl_c/kernel/layer_norm.h | 49 + .../cpu/nnacl_c/kernel/local_response_norm.c | 77 + .../cpu/nnacl_c/kernel/local_response_norm.h | 30 + .../kernel/cpu/nnacl_c/kernel/log_softmax.c | 120 + .../kernel/cpu/nnacl_c/kernel/log_softmax.h | 31 + .../litert/kernel/cpu/nnacl_c/kernel/matmul.c | 176 + .../litert/kernel/cpu/nnacl_c/kernel/matmul.h | 25 + .../kernel/cpu/nnacl_c/kernel/matmul_arm32.c | 110 + .../kernel/cpu/nnacl_c/kernel/matmul_arm32.h | 28 + .../kernel/cpu/nnacl_c/kernel/matmul_arm64.c | 214 + .../kernel/cpu/nnacl_c/kernel/matmul_arm64.h | 28 + .../kernel/cpu/nnacl_c/kernel/matmul_avx.c | 169 + .../kernel/cpu/nnacl_c/kernel/matmul_avx.h | 28 + .../kernel/cpu/nnacl_c/kernel/matmul_avx512.c | 708 +++ .../kernel/cpu/nnacl_c/kernel/matmul_avx512.h | 27 + .../kernel/cpu/nnacl_c/kernel/matmul_base.c | 676 +++ .../kernel/cpu/nnacl_c/kernel/matmul_base.h | 35 + .../kernel/cpu/nnacl_c/kernel/matmul_create.c | 82 + .../kernel/cpu/nnacl_c/kernel/matmul_create.h | 24 + .../kernel/cpu/nnacl_c/kernel/matmul_sse.c | 110 + .../kernel/cpu/nnacl_c/kernel/matmul_sse.h | 27 + .../kernel/cpu/nnacl_c/kernel/matmul_struct.h | 133 + .../kernel/cpu/nnacl_c/kernel/nllloss.c | 63 + .../kernel/cpu/nnacl_c/kernel/nllloss.h | 32 + .../cpu/nnacl_c/kernel/non_max_suppression.c | 126 + .../cpu/nnacl_c/kernel/non_max_suppression.h | 34 + .../kernel/cpu/nnacl_c/kernel/non_zero.c | 69 + .../kernel/cpu/nnacl_c/kernel/non_zero.h | 30 + .../kernel/cpu/nnacl_c/kernel/one_hot.c | 193 + .../kernel/cpu/nnacl_c/kernel/one_hot.h | 37 + .../kernel/cpu/nnacl_c/kernel/ones_like.c | 67 + .../kernel/cpu/nnacl_c/kernel/ones_like.h | 31 + .../litert/kernel/cpu/nnacl_c/kernel/pad.c | 406 ++ .../litert/kernel/cpu/nnacl_c/kernel/pad.h | 51 + .../kernel/cpu/nnacl_c/kernel/pooling.c | 159 + .../kernel/cpu/nnacl_c/kernel/pooling.h | 54 + .../litert/kernel/cpu/nnacl_c/kernel/pow.c | 79 + .../litert/kernel/cpu/nnacl_c/kernel/pow.h | 31 + .../litert/kernel/cpu/nnacl_c/kernel/prelu.c | 111 + .../litert/kernel/cpu/nnacl_c/kernel/prelu.h | 34 + .../kernel/cpu/nnacl_c/kernel/prior_box.c | 190 + .../kernel/cpu/nnacl_c/kernel/prior_box.h | 36 + .../kernel/cpu/nnacl_c/kernel/ragged_range.c | 74 + .../kernel/cpu/nnacl_c/kernel/ragged_range.h | 35 + .../litert/kernel/cpu/nnacl_c/kernel/range.c | 74 + .../litert/kernel/cpu/nnacl_c/kernel/range.h | 31 + .../litert/kernel/cpu/nnacl_c/kernel/rank.c | 44 + .../litert/kernel/cpu/nnacl_c/kernel/rank.h | 31 + .../litert/kernel/cpu/nnacl_c/kernel/reduce.c | 434 ++ .../litert/kernel/cpu/nnacl_c/kernel/reduce.h | 72 + .../kernel/cpu/nnacl_c/kernel/reshape.c | 96 + .../kernel/cpu/nnacl_c/kernel/reshape.h | 34 + .../kernel/cpu/nnacl_c/kernel/reverse.c | 166 + .../kernel/cpu/nnacl_c/kernel/reverse.h | 36 + .../litert/kernel/cpu/nnacl_c/kernel/scale.c | 333 ++ .../litert/kernel/cpu/nnacl_c/kernel/scale.h | 41 + .../litert/kernel/cpu/nnacl_c/kernel/shape.c | 51 + .../litert/kernel/cpu/nnacl_c/kernel/shape.h | 31 + .../litert/kernel/cpu/nnacl_c/kernel/size.c | 44 + .../litert/kernel/cpu/nnacl_c/kernel/size.h | 30 + .../litert/kernel/cpu/nnacl_c/kernel/slice.c | 76 + .../litert/kernel/cpu/nnacl_c/kernel/slice.h | 36 + .../kernel/cpu/nnacl_c/kernel/softmax.c | 157 + .../kernel/cpu/nnacl_c/kernel/softmax.h 
| 39 + .../litert/kernel/cpu/nnacl_c/kernel/splice.c | 79 + .../litert/kernel/cpu/nnacl_c/kernel/splice.h | 30 + .../litert/kernel/cpu/nnacl_c/kernel/stack.c | 138 + .../litert/kernel/cpu/nnacl_c/kernel/stack.h | 41 + .../kernel/cpu/nnacl_c/kernel/strided_slice.c | 278 ++ .../kernel/cpu/nnacl_c/kernel/strided_slice.h | 47 + .../litert/kernel/cpu/nnacl_c/kernel/tile.c | 182 + .../litert/kernel/cpu/nnacl_c/kernel/tile.h | 48 + .../kernel/cpu/nnacl_c/kernel/transpose.c | 358 ++ .../kernel/cpu/nnacl_c/kernel/transpose.h | 49 + .../litert/kernel/cpu/nnacl_c/kernel/tril.c | 89 + .../litert/kernel/cpu/nnacl_c/kernel/tril.h | 32 + .../litert/kernel/cpu/nnacl_c/kernel/triu.c | 89 + .../litert/kernel/cpu/nnacl_c/kernel/triu.h | 32 + .../litert/kernel/cpu/nnacl_c/kernel/unique.c | 66 + .../litert/kernel/cpu/nnacl_c/kernel/unique.h | 32 + .../litert/kernel/cpu/nnacl_c/kernel/where.c | 298 ++ .../litert/kernel/cpu/nnacl_c/kernel/where.h | 44 + .../kernel/cpu/nnacl_c/kernel/zeros_like.c | 43 + .../kernel/cpu/nnacl_c/kernel/zeros_like.h | 31 + .../kernel/cpu/nnacl_c/l2_norm_parameter.h | 41 + .../kernel/cpu/nnacl_c/layer_norm_parameter.h | 37 + .../nnacl_c/local_response_norm_parameter.h | 31 + .../cpu/nnacl_c/lsh_projection_parameter.h | 35 + .../kernel/cpu/nnacl_c/lstm_parameter.h | 44 + .../kernel/cpu/nnacl_c/matmul_parameter.h | 96 + .../litert/kernel/cpu/nnacl_c/mul_parameter.h | 32 + .../kernel/cpu/nnacl_c/nllloss_parameter.h | 27 + .../litert/kernel/cpu/nnacl_c/nnacl_common.c | 57 + .../litert/kernel/cpu/nnacl_c/nnacl_common.h | 109 + .../litert/kernel/cpu/nnacl_c/nnacl_utils.c | 27 + .../litert/kernel/cpu/nnacl_c/nnacl_utils.h | 39 + .../nnacl_c/non_max_suppression_parameter.h | 28 + .../kernel/cpu/nnacl_c/one_hot_parameter.h | 26 + .../src/litert/kernel/cpu/nnacl_c/op_base.h | 802 +++ .../cpu/nnacl_c/op_simd_header_file.h.in | 36 + .../cpu/nnacl_c/optimize/CMakeLists.txt | 62 + .../src/litert/kernel/cpu/nnacl_c/pack.h | 23 + .../litert/kernel/cpu/nnacl_c/pad_parameter.h | 38 + .../cpu/nnacl_c/partial_fusion_parameter.h | 29 + .../kernel/cpu/nnacl_c/pooling_parameter.h | 55 + .../litert/kernel/cpu/nnacl_c/pow_parameter.h | 37 + .../kernel/cpu/nnacl_c/predict_parameter.h | 32 + .../kernel/cpu/nnacl_c/prelu_parameter.h | 26 + .../kernel/cpu/nnacl_c/prior_box_parameter.h | 39 + .../kernel/cpu/nnacl_c/random_parameter.h | 34 + .../kernel/cpu/nnacl_c/range_parameter.h | 29 + .../kernel/cpu/nnacl_c/reduce_parameter.h | 29 + .../cpu/nnacl_c/reduce_scatter_parameter.h | 31 + .../kernel/cpu/nnacl_c/reshape_parameter.h | 40 + .../kernel/cpu/nnacl_c/resize_parameter.h | 37 + .../kernel/cpu/nnacl_c/reverse_parameter.h | 30 + .../cpu/nnacl_c/reverse_sequence_parameter.h | 45 + .../kernel/cpu/nnacl_c/scale_parameter.h | 39 + .../cpu/nnacl_c/scatter_elements_parameter.h | 25 + .../kernel/cpu/nnacl_c/scatter_nd_parameter.h | 29 + .../cpu/nnacl_c/sequence_unstack_parameter.h | 34 + .../kernel/cpu/nnacl_c/sigmoid_parameter.h | 41 + .../kernel/cpu/nnacl_c/skip_gram_parameter.h | 30 + .../kernel/cpu/nnacl_c/slice_parameter.h | 35 + .../kernel/cpu/nnacl_c/softmax_parameter.h | 27 + .../cpu/nnacl_c/space_to_depth_parameter.h | 27 + .../cpu/nnacl_c/sparse_to_dense_parameter.h | 32 + .../kernel/cpu/nnacl_c/splice_parameter.h | 28 + .../kernel/cpu/nnacl_c/split_parameter.h | 63 + .../kernel/cpu/nnacl_c/squeeze_parameter.h | 46 + .../kernel/cpu/nnacl_c/stack_parameter.h | 27 + .../cpu/nnacl_c/strided_slice_parameter.h | 43 + .../cpu/nnacl_c/tensor_array_parameter.h | 29 + .../src/litert/kernel/cpu/nnacl_c/tensor_c.h | 31 
+ .../kernel/cpu/nnacl_c/tensor_c_utils.c | 439 ++ .../kernel/cpu/nnacl_c/tensor_c_utils.h | 47 + .../litert/kernel/cpu/nnacl_c/tensorlist_c.h | 41 + .../kernel/cpu/nnacl_c/tensorlist_c_utils.c | 82 + .../kernel/cpu/nnacl_c/tensorlist_c_utils.h | 38 + .../kernel/cpu/nnacl_c/tensorlist_parameter.h | 32 + .../kernel/cpu/nnacl_c/tile_parameter.h | 28 + .../kernel/cpu/nnacl_c/transpose_parameter.h | 44 + .../kernel/cpu/nnacl_c/triu_tril_parameter.h | 31 + .../kernel/cpu/nnacl_c/unsqueeze_parameter.h | 48 + .../kernel/cpu/nnacl_c/unstack_parameter.h | 34 + .../kernel/cpu/nnacl_c/upsample_parameter.h | 29 + .../kernel/cpu/nnacl_c/where_parameter.h | 25 + .../litert/kernel/cpu/string/lsh_projection.h | 2 +- .../src/litert/kernel/cpu/string/predict.h | 2 +- .../src/litert/kernel/cpu/string/skip_gram.h | 2 +- .../kernel/gpu/opencl/opencl_executor.cc | 2 +- .../litert/kernel/opencl/kernel/activation.h | 2 +- .../litert/kernel/opencl/kernel/argminmax.h | 4 +- .../litert/kernel/opencl/kernel/arithmetic.cc | 2 +- .../kernel/opencl/kernel/arithmetic_self.h | 2 +- .../kernel/opencl/kernel/batch_to_space_nd.h | 2 +- .../litert/kernel/opencl/kernel/batchnorm.cc | 2 +- .../litert/kernel/opencl/kernel/batchnorm.h | 2 +- .../src/litert/kernel/opencl/kernel/concat.h | 2 +- .../src/litert/kernel/opencl/kernel/conv2d.h | 2 +- .../kernel/opencl/kernel/conv2d_transpose.cc | 2 +- .../kernel/opencl/kernel/conv2d_transpose.h | 2 +- .../src/litert/kernel/opencl/kernel/crop.h | 2 +- .../kernel/opencl/kernel/depthwise_conv2d.cc | 4 +- .../kernel/opencl/kernel/depthwise_conv2d.h | 2 +- .../src/litert/kernel/opencl/kernel/fill.h | 2 +- .../kernel/opencl/kernel/fullconnection.cc | 2 +- .../kernel/opencl/kernel/fullconnection.h | 2 +- .../kernel/opencl/kernel/fusion_eltwise.cc | 6 +- .../kernel/opencl/kernel/fusion_eltwise.h | 2 +- .../src/litert/kernel/opencl/kernel/gather.h | 2 +- .../opencl/kernel/int8/arithmetic_int8.cc | 4 +- .../litert/kernel/opencl/kernel/layer_norm.cc | 2 +- .../src/litert/kernel/opencl/kernel/matmul.h | 2 +- .../src/litert/kernel/opencl/kernel/one_hot.h | 2 +- .../src/litert/kernel/opencl/kernel/pad.h | 2 +- .../litert/kernel/opencl/kernel/pooling2d.h | 2 +- .../src/litert/kernel/opencl/kernel/power.h | 2 +- .../src/litert/kernel/opencl/kernel/prelu.cc | 4 +- .../src/litert/kernel/opencl/kernel/reduce.h | 2 +- .../src/litert/kernel/opencl/kernel/resize.h | 2 +- .../src/litert/kernel/opencl/kernel/scale.cc | 2 +- .../src/litert/kernel/opencl/kernel/scale.h | 2 +- .../litert/kernel/opencl/kernel/softmax.cc | 2 +- .../src/litert/kernel/opencl/kernel/softmax.h | 2 +- .../kernel/opencl/kernel/space_to_batch_nd.h | 2 +- .../kernel/opencl/kernel/space_to_depth.h | 2 +- .../kernel/opencl/kernel/sparse_to_dense.h | 2 +- .../src/litert/kernel/opencl/kernel/split.h | 2 +- .../src/litert/kernel/opencl/kernel/stack.h | 2 +- .../kernel/opencl/kernel/strided_slice.h | 2 +- .../litert/kernel/opencl/kernel/transpose.h | 2 +- .../litert/kernel/opencl/kernel/winograd.cc | 4 +- .../src/litert/kernel/opencl/opencl_fusion.cc | 14 +- .../src/litert/kernel/opencl/opencl_kernel.h | 2 +- .../src/litert/kernel/opencl/utils.h | 2 +- mindspore-lite/src/litert/kernel_exec_util.cc | 2 +- mindspore-lite/src/litert/kernel_registry.cc | 2 +- mindspore-lite/src/litert/lite_kernel.h | 2 +- mindspore-lite/src/litert/lite_model.h | 2 +- mindspore-lite/src/litert/mindrt_executor.cc | 4 +- .../litert/pass/format_pass/format_pass.cc | 2 +- .../pass/format_pass/insert_transpose.cc | 2 +- .../src/litert/pass/format_pass/pass_utils.cc | 4 +- 
.../pass/format_pass/transpose_strategy.cc | 24 +- .../cast_gather_reduce_fusion_pass.cc | 2 +- .../pass/online_fusion/online_fusion_pass.cc | 6 +- .../pass/online_fusion/online_fusion_pass.h | 4 +- .../reduce_concat_fusion_pass.cc | 4 +- .../online_fusion/reduce_concat_fusion_pass.h | 2 +- .../split_reduce_concat_fusion_pass.cc | 6 +- .../split_reduce_concat_fusion_pass.h | 2 +- .../src/litert/runtime_packed_node_pass.cc | 6 +- mindspore-lite/src/litert/runtime_pass.cc | 2 +- .../src/litert/runtime_shape_fusion_pass.cc | 2 +- mindspore-lite/src/litert/scheduler.cc | 4 +- .../src/litert/schema_tensor_wrapper.cc | 2 +- mindspore-lite/src/litert/sub_graph_split.cc | 4 +- mindspore-lite/src/litert/sub_graph_split.h | 2 +- .../src/litert/thread_cost_model.cc | 2 +- mindspore-lite/src/litert/thread_cost_model.h | 2 +- mindspore-lite/src/litert/weight_decoder.cc | 2 +- mindspore-lite/src/litert/weight_decoder.h | 4 +- mindspore-lite/src/tensor.h | 4 +- mindspore-lite/src/tensorlist.cc | 2 +- mindspore-lite/src/tensorlist.h | 2 +- mindspore-lite/src/train/opt_allocator.cc | 2 +- .../train/optimizer/fusion/gru_fusion_pass.cc | 2 +- mindspore-lite/src/train/train_loop.cc | 2 +- .../src/train/train_populate_parameter.cc | 30 +- mindspore-lite/test/CMakeLists.txt | 2 +- mindspore-lite/test/common/common_test.h | 4 +- .../test/ut/nnacl/infer/adam_infer_test.cc | 2 +- .../infer/adam_weight_decay_infer_test.cc | 2 +- .../test/ut/nnacl/infer/addn_infer_test.cc | 2 +- .../nnacl/infer/apply_momentum_infer_test.cc | 2 +- .../test/ut/nnacl/infer/argmax_infer_test.cc | 2 +- .../test/ut/nnacl/infer/argmin_infer_test.cc | 2 +- .../infer/arithmetic_compare_infer_test.cc | 2 +- .../ut/nnacl/infer/arithmetic_infer_test.cc | 2 +- .../ut/nnacl/infer/assign_add_infer_test.cc | 2 +- .../test/ut/nnacl/infer/assign_infer_test.cc | 2 +- .../infer/audio_spectrogram_infer_test.cc | 2 +- .../nnacl/infer/batch_to_space_infer_test.cc | 2 +- .../ut/nnacl/infer/bias_grad_infer_test.cc | 2 +- .../infer/binary_cross_entropy_infer_test.cc | 2 +- .../test/ut/nnacl/infer/bn_grad_infer_test.cc | 2 +- .../ut/nnacl/infer/broadcast_to_infer_test.cc | 2 +- .../test/ut/nnacl/infer/cast_infer_test.cc | 2 +- .../test/ut/nnacl/infer/concat_infer_test.cc | 2 +- .../infer/constant_of_shape_infer_test.cc | 2 +- .../infer/conv2d_grad_filter_infer_test.cc | 2 +- .../infer/conv2d_grad_input_infer_test.cc | 2 +- .../test/ut/nnacl/infer/conv2d_infer_test.cc | 2 +- .../nnacl/infer/crop_and_resize_infer_test.cc | 2 +- .../test/ut/nnacl/infer/crop_infer_test.cc | 2 +- .../test/ut/nnacl/infer/cumsum_infer_test.cc | 4 +- .../custom_extract_features_infer_test.cc | 2 +- .../infer/custom_normalize_infer_test.cc | 2 +- .../nnacl/infer/custom_predict_infer_test.cc | 2 +- .../ut/nnacl/infer/deconv2d_infer_test.cc | 2 +- .../nnacl/infer/depth_to_space_infer_test.cc | 2 +- .../infer/depthwise_conv2d_infer_test.cc | 2 +- .../detection_post_process_infer_test.cc | 2 +- .../ut/nnacl/infer/dropout_grad_infer_test.cc | 2 +- .../infer/embedding_lookup_infer_test.cc | 2 +- .../ut/nnacl/infer/expand_dims_infer_test.cc | 2 +- .../ut/nnacl/infer/fft_imag_infer_test.cc | 2 +- .../test/ut/nnacl/infer/fill_infer_test.cc | 4 +- .../ut/nnacl/infer/flatten_grad_infer_test.cc | 2 +- .../test/ut/nnacl/infer/flatten_infer_test.cc | 2 +- .../nnacl/infer/full_connection_infer_test.cc | 2 +- .../nnacl/infer/fused_batchnorm_infer_test.cc | 2 +- .../test/ut/nnacl/infer/gather_infer_test.cc | 2 +- .../ut/nnacl/infer/gather_nd_infer_test.cc | 4 +- 
.../group_conv2d_grad_input_infer_test.cc | 2 +- .../test/ut/nnacl/infer/gru_infer_test.cc | 2 +- .../infer/hashtable_lookup_infer_test.cc | 2 +- .../infer/invert_permutation_infer_test.cc | 2 +- .../ut/nnacl/infer/layer_norm_infer_test.cc | 2 +- .../nnacl/infer/lsh_projection_infer_test.cc | 2 +- .../test/ut/nnacl/infer/lstm_infer_test.cc | 2 +- .../test/ut/nnacl/infer/matmul_infer_test.cc | 2 +- .../ut/nnacl/infer/max_min_grad_infer_test.cc | 4 +- .../test/ut/nnacl/infer/mfcc_infer_test.cc | 2 +- .../ut/nnacl/infer/nllloss_grad_infer_test.cc | 2 +- .../test/ut/nnacl/infer/nllloss_infer_test.cc | 2 +- .../test/ut/nnacl/infer/one_hot_infer_test.cc | 2 +- .../test/ut/nnacl/infer/pad_infer_test.cc | 2 +- .../ut/nnacl/infer/pooling_grad_infer_test.cc | 2 +- .../test/ut/nnacl/infer/pooling_infer_test.cc | 2 +- .../test/ut/nnacl/infer/power_infer_test.cc | 2 +- .../infer/quant_dtype_cast_infer_test.cc | 2 +- .../random_standard_normal_infer_test.cc | 2 +- .../test/ut/nnacl/infer/range_infer_test.cc | 4 +- .../test/ut/nnacl/infer/rank_infer_test.cc | 2 +- .../test/ut/nnacl/infer/reduce_infer_test.cc | 2 +- .../test/ut/nnacl/infer/reshape_infer_test.cc | 4 +- .../test/ut/nnacl/infer/resize_infer_test.cc | 2 +- .../test/ut/nnacl/infer/rfft_infer_test.cc | 2 +- .../ut/nnacl/infer/roi_pooling_infer_test.cc | 2 +- .../nnacl/infer/scatter_nd_add_infer_test.cc | 4 +- .../ut/nnacl/infer/scatter_nd_infer_test.cc | 2 +- .../test/ut/nnacl/infer/select_infer_test.cc | 2 +- .../test/ut/nnacl/infer/sgd_infer_test.cc | 2 +- .../test/ut/nnacl/infer/shape_infer_test.cc | 2 +- .../test/ut/nnacl/infer/size_infer_test.cc | 2 +- .../ut/nnacl/infer/skip_gram_infer_test.cc | 2 +- .../test/ut/nnacl/infer/slice_infer_test.cc | 2 +- .../infer/softmax_cross_entropy_infer_test.cc | 2 +- .../test/ut/nnacl/infer/softmax_infer_test.cc | 2 +- .../nnacl/infer/space_to_batch_infer_test.cc | 2 +- .../infer/space_to_batch_nd_infer_test.cc | 2 +- .../nnacl/infer/space_to_depth_infer_test.cc | 2 +- .../nnacl/infer/sparse_to_dense_infer_test.cc | 2 +- .../test/ut/nnacl/infer/split_infer_test.cc | 2 +- .../test/ut/nnacl/infer/squeeze_infer_test.cc | 2 +- .../test/ut/nnacl/infer/stack_infer_test.cc | 2 +- .../nnacl/infer/strided_slice_infer_test.cc | 2 +- .../infer/tensorlist_fromtensor_infer_test.cc | 2 +- .../infer/tensorlist_getitem_infer_test.cc | 2 +- .../infer/tensorlist_reserve_infer_test.cc | 2 +- .../infer/tensorlist_setitem_infer_test.cc | 2 +- .../infer/tensorlist_stack_infer_test.cc | 2 +- .../test/ut/nnacl/infer/tile_infer_test.cc | 6 +- .../test/ut/nnacl/infer/topk_infer_test.cc | 2 +- .../ut/nnacl/infer/transpose_infer_test.cc | 2 +- .../test/ut/nnacl/infer/unique_infer_test.cc | 2 +- .../infer/unsorted_segment_sum_infer_test.cc | 2 +- .../ut/nnacl/infer/unsqueeze_infer_test.cc | 4 +- .../test/ut/nnacl/infer/unstack_infer_test.cc | 2 +- .../test/ut/nnacl/infer/where_infer_test.cc | 2 +- .../nnacl/int8/quant_dtype_cast_int8_test.cc | 4 +- .../test/ut/nnacl/kernel/cast_test.cc | 6 +- .../runtime/kernel/arm/common/pack_tests.cc | 8 +- .../kernel/arm/common/strided_slice_tests.cc | 2 +- .../fp16_grad/activation_grad_fp16_test.cc | 2 +- .../arithmetic_fp16_self_grad_tests.cc | 2 +- .../arm/fp32-sparsity/matmul_fp32_tests.cc | 2 +- .../kernel/arm/fp32/activation_fp32_test.cc | 2 +- .../arm/fp32/batch_to_space_fp32_test.cc | 6 +- .../kernel/arm/fp32/batchnorm_fp32_tests.cc | 2 +- .../kernel/arm/fp32/conv1x1_fp32_tests.cc | 2 +- .../runtime/kernel/arm/fp32/crop_fp32_test.cc | 2 +- .../runtime/kernel/arm/fp32/cumsum_tests.cc | 2 
+- .../arm/fp32/deconvolution_fp32_tests.cc | 4 +- .../arm/fp32/depth_to_space_fp32_test.cc | 8 +- .../arm/fp32/embedding_lookup_fp32_test.cc | 2 +- .../arm/fp32/fullconnection_fp32_tests.cc | 2 +- .../kernel/arm/fp32/logicalor_fp32_test.cc | 2 +- .../arm/fp32/lsh_projection_fp32_tests.cc | 2 +- .../kernel/arm/fp32/lstm_fp32_tests.cc | 2 +- .../kernel/arm/fp32/matmul_fp32_tests.cc | 4 +- .../kernel/arm/fp32/nllloss_fp32_test.cc | 2 +- .../kernel/arm/fp32/one_hot_fp32_test.cc | 2 +- .../kernel/arm/fp32/power_fp32_tests.cc | 2 +- .../arm/fp32/ragged_range_fp32_tests.cc | 2 +- .../kernel/arm/fp32/reduce_fp32_tests.cc | 2 +- .../arm/fp32/resize_bilinear_fp32_tests.cc | 2 +- .../resize_nearest_neighbor_fp32_tests.cc | 2 +- .../arm/fp32/reverse_sequence_fp32_tests.cc | 2 +- .../kernel/arm/fp32/scale_fp32_tests.cc | 6 +- .../arm/fp32/scatter_nd_add_fp32_test.cc | 2 +- .../kernel/arm/fp32/scatter_nd_fp32_tests.cc | 2 +- .../runtime/kernel/arm/fp32/skip_gram_fp32.cc | 2 +- .../runtime/kernel/arm/fp32/softmax_tests.cc | 2 +- .../arm/fp32/space_to_batch_fp32_tests.cc | 2 +- .../arm/fp32/space_to_depth_fp32_tests.cc | 4 +- .../arm/fp32/sparse_to_dense_fp32_tests.cc | 2 +- .../kernel/arm/fp32/stack_fp32_test.cc | 2 +- .../kernel/arm/fp32/tile_fp32_tests.cc | 2 +- .../kernel/arm/fp32/topk_fp32_tests.cc | 2 +- .../kernel/arm/fp32/transpose_fp32_tests.cc | 4 +- .../kernel/arm/fp32/uniform_real_fp32_test.cc | 2 +- .../kernel/arm/fp32/unique_fp32_tests.cc | 2 +- .../kernel/arm/fp32/unstack_fp32_tests.cc | 2 +- .../fp32_grad/activation_grad_fp32_tests.cc | 2 +- .../fp32_grad/arithmetic_grad_fp32_tests.cc | 2 +- .../kernel/arm/fp32_grad/bn_grad_fp32_test.cc | 6 +- .../fp32_grad/convolution_grad_fp32_tests.cc | 2 +- .../deconvolution_grad_fp32_tests.cc | 2 +- .../arm/fp32_grad/pooling_grad_fp32_tests.cc | 4 +- .../arm/fp32_grad/softmax_grad_fp32_tests.cc | 2 +- .../arm/int8/arithmetic_self_int8_tests.cc | 2 +- .../kernel/arm/int8/batchnorm_int8_test.cc | 4 +- .../kernel/arm/int8/concat_int8_tests.cc | 2 +- .../kernel/arm/int8/conv_1x1_int8_tests.cc | 4 +- .../kernel/arm/int8/crop_int8_tests.cc | 2 +- .../kernel/arm/int8/deconv_int8_tests.cc | 6 +- .../arm/int8/fullconnection_int8_tests.cc | 4 +- .../kernel/arm/int8/gatherNd_int8_test.cc | 6 +- .../kernel/arm/int8/gather_int8_test.cc | 4 +- .../kernel/arm/int8/hswish_int8_tests.cc | 2 +- .../kernel/arm/int8/l2_norm_int8_tests.cc | 2 +- .../kernel/arm/int8/matmul_int8_tests.cc | 6 +- .../runtime/kernel/arm/int8/mul_int8_tests.cc | 4 +- .../runtime/kernel/arm/int8/pad_int8_tests.cc | 2 +- .../kernel/arm/int8/power_int8_tests.cc | 2 +- .../kernel/arm/int8/prelu_int8_tests.cc | 2 +- .../kernel/arm/int8/quant_dtype_cast_tests.cc | 2 +- .../kernel/arm/int8/reduce_int8_tests.cc | 2 +- .../kernel/arm/int8/reshape_int8_tests.cc | 2 +- .../arm/int8/resize_bilinear_int8_tests.cc | 2 +- .../resize_nearest_neighbor_int8_tests.cc | 2 +- .../src/runtime/kernel/arm/int8/scale_int8.cc | 2 +- .../kernel/arm/int8/sigmoid_int8_tests.cc | 2 +- .../kernel/arm/int8/softmax_int8_tests.cc | 2 +- .../arm/int8/space_to_batch_int8_tests.cc | 2 +- .../kernel/arm/int8/split_int8_tests.cc | 2 +- .../kernel/arm/int8/squeeze_int8_tests.cc | 2 +- .../kernel/arm/int8/topk_int8_tests.cc | 2 +- .../kernel/arm/int8/unsqueeze_int8_tests.cc | 2 +- .../runtime/kernel/arm/string/normalize.cc | 2 +- .../runtime/kernel/cuda/batchtospace_tests.cc | 2 +- .../runtime/kernel/opencl/activation_tests.cc | 2 +- .../runtime/kernel/opencl/argminmax_tests.cc | 2 +- .../kernel/opencl/arithmetic_self_tests.cc | 2 +- 
.../runtime/kernel/opencl/arithmetic_tests.cc | 2 +- .../kernel/opencl/batch_to_space_nd_tests.cc | 2 +- .../runtime/kernel/opencl/batchnorm_tests.cc | 2 +- .../ut/src/runtime/kernel/opencl/common.cc | 2 +- .../ut/src/runtime/kernel/opencl/common.h | 2 +- .../src/runtime/kernel/opencl/concat_tests.cc | 2 +- .../src/runtime/kernel/opencl/conv2d_tests.cc | 2 +- .../kernel/opencl/conv2d_transpose_tests.cc | 2 +- .../src/runtime/kernel/opencl/crop_tests.cc | 2 +- .../kernel/opencl/depthwise_conv2d_tests.cc | 2 +- .../kernel/opencl/fullconnection_tests.cc | 2 +- .../src/runtime/kernel/opencl/gather_tests.cc | 2 +- .../runtime/kernel/opencl/layer_norm_tests.cc | 2 +- .../src/runtime/kernel/opencl/matmul_tests.cc | 2 +- .../runtime/kernel/opencl/one_hot_tests.cc | 2 +- .../ut/src/runtime/kernel/opencl/pad_tests.cc | 2 +- .../runtime/kernel/opencl/pooling_tests.cc | 2 +- .../src/runtime/kernel/opencl/prelu_tests.cc | 2 +- .../src/runtime/kernel/opencl/reduce_tests.cc | 2 +- .../runtime/kernel/opencl/reshape_tests.cc | 2 +- .../src/runtime/kernel/opencl/resize_tests.cc | 2 +- .../src/runtime/kernel/opencl/scale_tests.cc | 2 +- .../src/runtime/kernel/opencl/slice_tests.cc | 2 +- .../runtime/kernel/opencl/softmax_tests.cc | 2 +- .../kernel/opencl/space_to_batch_nd_tests.cc | 2 +- .../kernel/opencl/space_to_depth_tests.cc | 6 +- .../kernel/opencl/sparse_to_dense_tests.cc | 4 +- .../src/runtime/kernel/opencl/split_tests.cc | 2 +- .../src/runtime/kernel/opencl/stack_tests.cc | 2 +- .../kernel/opencl/strided_slice_tests.cc | 2 +- .../runtime/kernel/opencl/transpose_tests.cc | 2 +- .../test/ut/src/runtime/runtime_pass_tests.cc | 8 +- .../activation_fusion_inout_test.cc | 2 +- .../add_concat_act_fusion_inout_test.cc | 2 +- .../conv_act_fusion_inout_test.cc | 2 +- .../conv_bias_fusion_inout_test.cc | 2 +- .../conv_fusion_inout_test.cc | 2 +- .../fusion_inout_test/fusion_inout_test.cc | 2 +- .../matmul_act_fusion_inout_test.cc | 2 +- .../matmul_fusion_inout_test.cc | 2 +- .../matmul_fusion_inout_test.h | 2 +- .../matmul_mul_fusion_inout_test.cc | 2 +- .../trans_matmul_fusion_inout_test.cc | 2 +- mindspore-lite/tools/benchmark/CMakeLists.txt | 4 +- .../tools/benchmark/benchmark_base.h | 2 +- .../tools/benchmark/benchmark_unified_api.cc | 2 +- .../tools/benchmark_train/CMakeLists.txt | 2 +- .../tools/common/func_graph_subgraph.cc | 2 +- mindspore-lite/tools/common/graph_util.cc | 2 +- mindspore-lite/tools/common/graph_util.h | 2 +- .../tools/common/meta_graph_serializer.cc | 2 +- .../tools/common/meta_graph_utils.cc | 2 +- mindspore-lite/tools/common/node_util.cc | 2 +- mindspore-lite/tools/common/opengl_util.h | 2 +- .../tools/common/statistic_utils.cc | 2 +- mindspore-lite/tools/common/statistic_utils.h | 2 +- mindspore-lite/tools/common/tensor_util.cc | 2 +- mindspore-lite/tools/converter/CMakeLists.txt | 2 +- .../converter/adapter/acl/common/utils.cc | 2 +- .../adapter/acl/infer/custom_infer.cc | 2 +- .../acl/infer/flash_attention_infer.cc | 2 +- .../acl/infer/forward_rasterize_infer.cc | 2 +- .../adapter/acl/mapper/arithmetic_mapper.cc | 2 +- .../adapter/acl/mapper/cast_mapper.cc | 2 +- .../acl/mapper/constant_of_shape_mapper.cc | 2 +- .../mapper/conv2d_transpose_fusion_mapper.cc | 2 +- .../adapter/acl/mapper/conv_base_mapper.cc | 2 +- .../adapter/acl/mapper/gather_d_mapper.cc | 2 +- .../acl/mapper/gather_fusion_mapper.cc | 2 +- .../adapter/acl/mapper/gru_mapper.cc | 2 +- .../adapter/acl/mapper/lstm_mapper.cc | 2 +- .../acl/mapper/matmul_fusion_mapper.cc | 2 +- .../acl/mapper/maxpool_fusion_mapper.cc | 2 
+- .../adapter/acl/mapper/onehot_mapper.cc | 2 +- .../adapter/acl/mapper/primitive_mapper.cc | 2 +- .../acl/mapper/quant_dtype_cast_mapper.cc | 2 +- .../adapter/acl/mapper/reshape_mapper.cc | 2 +- .../adapter/acl/mapper/resize_mapper.cc | 2 +- .../adapter/acl/mapper/stridedslice_mapper.cc | 2 +- .../adapter/acl/mapper/tile_fusion_mapper.cc | 2 +- .../adapter/acl/mapper/topk_fusion_mapper.cc | 2 +- .../adapter/acl/mapper/transpose_mapper.cc | 2 +- .../adapter/acl/mapper/upsample_mapper.cc | 2 +- .../adapter/acl/mapper/where_mapper.cc | 2 +- .../adapter/acl/src/acl_pass_impl.cc | 2 +- .../tools/converter/anf_transform.cc | 2 +- .../tools/converter/anf_transform_for_ge.cc | 2 +- .../config_parser/acl_option_param_parser.cc | 2 +- mindspore-lite/tools/converter/converter.cc | 2 +- .../tools/converter/converter_funcgraph.cc | 2 +- .../tools/converter/converter_packed_node.cc | 2 +- .../tools/converter/export_model.cc | 2 +- .../tools/converter/import/mindir_adjust.cc | 2 +- .../import/mindir_control_flow_adjust.cc | 2 +- .../converter/import/mindspore_importer.cc | 2 +- .../converter/import/primitive_adjust.cc | 2 +- .../import/remove_public_primitive.cc | 2 +- .../legacy_optimizer/fusion/fusion_pass.cc | 2 +- .../legacy_optimizer/fusion/fusion_pattern.h | 2 +- .../legacy_optimizer/graph/dtype_trans_pass.h | 2 +- .../legacy_optimizer/graph/infershape_pass.cc | 2 +- .../micro/coder/allocator/memory_manager.cc | 2 +- .../generator/component/common_component.cc | 2 +- .../component/const_blocks/debug_utils.cc | 2 +- .../generator/component/train_component.cc | 2 +- .../micro/coder/generator/generator.cc | 4 +- .../tools/converter/micro/coder/log.h | 2 +- .../coder/opcoders/base/conv2d_base_coder.cc | 4 +- .../coder/opcoders/base/conv2d_base_coder.h | 2 +- .../base/detection_post_process_base_coder.cc | 8 +- .../base/detection_post_process_base_coder.h | 2 +- .../coder/opcoders/base/dtype_cast_coder.cc | 2 +- .../coder/opcoders/base/dtype_cast_coder.h | 2 +- .../base/full_connection_base_coder.h | 2 +- .../opcoders/base/quant_dtype_cast_coder.cc | 2 +- .../opcoders/base/quant_dtype_cast_coder.h | 2 +- .../coder/opcoders/base/reduce_base_coder.h | 2 +- .../coder/opcoders/base/resize_base_coder.h | 2 +- .../coder/opcoders/base/softmax_base_coder.h | 4 +- .../coder/opcoders/base/stack_base_coder.cc | 2 +- .../coder/opcoders/base/stack_base_coder.h | 2 +- .../opcoders/base/strided_slice_base_coder.cc | 2 +- .../opcoders/base/strided_slice_base_coder.h | 2 +- .../base/strided_slice_dynamic_base_coder.cc | 2 +- .../base/strided_slice_dynamic_base_coder.h | 4 +- .../coder/opcoders/base/unstack_base_coder.cc | 2 +- .../coder/opcoders/base/unstack_base_coder.h | 4 +- .../opcoders/cmsis-nn/int8/add_int8_coder.cc | 4 +- .../cmsis-nn/int8/conv2d_base_coder.cc | 2 +- .../cmsis-nn/int8/conv2d_base_coder.h | 2 +- .../cmsis-nn/int8/conv2d_int8_coder.h | 2 +- .../cmsis-nn/int8/fullconnection_int8_coder.h | 2 +- .../opcoders/cmsis-nn/int8/mul_int8_coder.cc | 2 +- .../cmsis-nn/int8/pooling_int8_coder.h | 2 +- .../coder/opcoders/custom/custom_coder.cc | 4 +- .../fp16/activation_dynamic_fp16_coder.cc | 2 +- .../nnacl/fp16/activation_fp16_coder.cc | 2 +- .../fp16/arithmetic_dynamic_fp16_coder.cc | 4 +- .../fp16/arithmetic_dynamic_fp16_coder.h | 6 +- .../nnacl/fp16/arithmetic_fp16_coder.cc | 6 +- .../nnacl/fp16/arithmetic_self_fp16_coder.cc | 2 +- .../nnacl/fp16/arithmetic_self_fp16_coder.h | 2 +- .../nnacl/fp16/concat_dynamic_fp16_coder.cc | 2 +- .../nnacl/fp16/concat_dynamic_fp16_coder.h | 2 +- 
.../opcoders/nnacl/fp16/concat_fp16_coder.cc | 2 +- .../opcoders/nnacl/fp16/concat_fp16_coder.h | 2 +- .../conv2d_delegate_dynamic_fp16_coder.cc | 6 +- .../fp16/conv2d_delegate_dynamic_fp16_coder.h | 2 +- .../nnacl/fp16/conv2d_delegate_fp16_coder.cc | 6 +- .../nnacl/fp16/conv2d_delegate_fp16_coder.h | 2 +- .../fp16/conv_depthwise_3x3_fp16_coder.cc | 4 +- .../fp16/conv_depthwise_3x3_fp16_coder.h | 2 +- .../nnacl/fp16/conv_depthwise_fp16_coder.cc | 6 +- .../fp16/conv_depthwise_sw_fp16_coder.cc | 6 +- .../convolution_1x1_dynamic_fp16_coder.cc | 12 +- .../fp16/convolution_1x1_dynamic_fp16_coder.h | 4 +- .../nnacl/fp16/convolution_1x1_fp16_coder.cc | 12 +- .../nnacl/fp16/convolution_1x1_fp16_coder.h | 2 +- .../fp16/convolution_dynamic_fp16_coder.cc | 10 +- .../fp16/convolution_dynamic_fp16_coder.h | 2 +- .../nnacl/fp16/convolution_fp16_coder.cc | 10 +- .../nnacl/fp16/convolution_fp16_coder.h | 2 +- .../fp16/convolution_winograd_fp16_coder.cc | 10 +- .../fp16/convolution_winograd_fp16_coder.h | 2 +- .../nnacl/fp16/custom_gru_fp16_coder.cc | 2 +- .../nnacl/fp16/custom_gru_fp16_coder.h | 2 +- .../nnacl/fp16/deconv2d_fp16_coder.cc | 18 +- .../opcoders/nnacl/fp16/deconv2d_fp16_coder.h | 2 +- .../nnacl/fp16/layernorm_fp16_coder.cc | 2 +- .../nnacl/fp16/layernorm_fp16_coder.h | 2 +- .../opcoders/nnacl/fp16/lstm_fp16_coder.cc | 4 +- .../opcoders/nnacl/fp16/lstm_fp16_coder.h | 2 +- .../fp16/lstm_mindir_dynamic_fp16_coder.cc | 2 +- .../fp16/lstm_mindir_dynamic_fp16_coder.h | 2 +- .../fp16/matmul_dynamic_fp16_base_coder.cc | 4 +- .../fp16/matmul_dynamic_fp16_base_coder.h | 2 +- .../nnacl/fp16/matmul_dynamic_fp16_coder.h | 2 +- .../nnacl/fp16/matmul_fp16_base_coder.cc | 4 +- .../nnacl/fp16/matmul_fp16_base_coder.h | 2 +- .../opcoders/nnacl/fp16/matmul_fp16_coder.h | 2 +- .../nnacl/fp16/pooling_dynamic_fp16_coder.cc | 2 +- .../nnacl/fp16/pooling_dynamic_fp16_coder.h | 4 +- .../opcoders/nnacl/fp16/pooling_fp16_coder.cc | 4 +- .../opcoders/nnacl/fp16/reduce_fp16_coder.cc | 2 +- .../opcoders/nnacl/fp16/resize_fp16_coder.cc | 6 +- .../opcoders/nnacl/fp16/resize_fp16_coder.h | 2 +- .../nnacl/fp16/scale_dynamic_fp16_coder.cc | 4 +- .../nnacl/fp16/scale_dynamic_fp16_coder.h | 4 +- .../opcoders/nnacl/fp16/scale_fp16_coder.cc | 6 +- .../opcoders/nnacl/fp16/scale_fp16_coder.h | 4 +- .../nnacl/fp16/slice_dynamic_fp16_coder.cc | 2 +- .../nnacl/fp16/slice_dynamic_fp16_coder.h | 4 +- .../opcoders/nnacl/fp16/slice_fp16_coder.cc | 2 +- .../opcoders/nnacl/fp16/slice_fp16_coder.h | 2 +- .../nnacl/fp16/softmax_dynamic_fp16_coder.cc | 4 +- .../nnacl/fp16/softmax_dynamic_fp16_coder.h | 4 +- .../opcoders/nnacl/fp16/softmax_fp16_coder.cc | 4 +- .../fp16/transpose_dynamic_fp16_coder.cc | 6 +- .../nnacl/fp16/transpose_dynamic_fp16_coder.h | 2 +- .../nnacl/fp16/transpose_fp16_coder.cc | 6 +- .../nnacl/fp16/transpose_fp16_coder.h | 2 +- .../nnacl/fp32/activation_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/addn_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/affine_fp32_coder.h | 2 +- .../nnacl/fp32/arithmetic_fp32_coder.cc | 16 +- .../nnacl/fp32/arithmetic_fp32_coder.h | 2 +- .../nnacl/fp32/arithmetic_self_fp32_coder.cc | 4 +- .../nnacl/fp32/arithmetic_self_fp32_coder.h | 4 +- .../nnacl/fp32/assign_add_fp32_coder.h | 2 +- .../nnacl/fp32/batchnorm_fp32_coder.cc | 10 +- .../nnacl/fp32/batchnorm_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/biasadd_fp32_coder.cc | 13 +- .../opcoders/nnacl/fp32/biasadd_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/concat_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/concat_fp32_coder.h | 2 +- 
.../nnacl/fp32/conv2d_delegate_fp32_coder.cc | 4 +- .../nnacl/fp32/conv2d_delegate_fp32_coder.h | 2 +- .../fp32/convolution_depthwise_fp32_coder.cc | 4 +- .../nnacl/fp32/convolution_fp32_coder.cc | 10 +- .../nnacl/fp32/convolution_fp32_coder.h | 2 +- .../fp32/convolution_winograd_fp32_coder.cc | 12 +- .../fp32/convolution_winograd_fp32_coder.h | 2 +- .../nnacl/fp32/custom_gru_fp32_coder.cc | 4 +- .../nnacl/fp32/custom_gru_fp32_coder.h | 2 +- .../nnacl/fp32/deconv2d_fp32_coder.cc | 18 +- .../opcoders/nnacl/fp32/deconv2d_fp32_coder.h | 6 +- .../opcoders/nnacl/fp32/exp_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/exp_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/fill_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/fill_fp32_coder.h | 2 +- .../nnacl/fp32/gather_dynamic_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/gather_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/gather_fp32_coder.h | 2 +- .../nnacl/fp32/groupnorm_fp32_coder.cc | 6 +- .../nnacl/fp32/instance_norm_fp32_coder.cc | 2 +- .../nnacl/fp32/instance_norm_fp32_coder.h | 2 +- .../nnacl/fp32/layernorm_fp32_coder.cc | 2 +- .../nnacl/fp32/layernorm_fp32_coder.h | 4 +- .../opcoders/nnacl/fp32/lstm_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/lstm_fp32_coder.h | 2 +- .../nnacl/fp32/matmul_fp32_base_coder.cc | 6 +- .../nnacl/fp32/matmul_fp32_base_coder.h | 2 +- .../opcoders/nnacl/fp32/matmul_fp32_coder.h | 2 +- .../nnacl/fp32/ones_like_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/pad_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/pad_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/pooling_fp32_coder.cc | 6 +- .../opcoders/nnacl/fp32/pooling_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/power_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/power_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/prelu_fp32_coder.cc | 6 +- .../opcoders/nnacl/fp32/reduce_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/resize_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/resize_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/scale_fp32_coder.cc | 6 +- .../opcoders/nnacl/fp32/scale_fp32_coder.h | 4 +- .../opcoders/nnacl/fp32/slice_fp32_coder.cc | 6 +- .../opcoders/nnacl/fp32/slice_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/softmax_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/splice_fp32_coder.cc | 6 +- .../nnacl/fp32/split_dynamic_fp32_coder.cc | 4 +- .../nnacl/fp32/split_dynamic_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/split_fp32_coder.cc | 2 +- .../opcoders/nnacl/fp32/split_fp32_coder.h | 2 +- .../opcoders/nnacl/fp32/tile_fp32_coder.cc | 4 +- .../opcoders/nnacl/fp32/tile_fp32_coder.h | 2 +- .../fp32/transpose_dynamic_fp32_coder.cc | 6 +- .../nnacl/fp32/transpose_dynamic_fp32_coder.h | 2 +- .../nnacl/fp32/transpose_fp32_coder.cc | 6 +- .../nnacl/fp32/transpose_fp32_coder.h | 2 +- .../nnacl/fp32_grad/activation_grad_coder.cc | 4 +- .../opcoders/nnacl/fp32_grad/adam_coder.cc | 4 +- ...softmax_cross_entropy_with_logits_coder.cc | 6 +- .../softmax_cross_entropy_with_logits_coder.h | 2 +- .../nnacl/int8/activation_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/add_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/add_int8_coder.h | 2 +- .../opcoders/nnacl/int8/affine_int8_coder.h | 2 +- .../nnacl/int8/arithmetic_self_int8_coder.cc | 2 +- .../nnacl/int8/arithmetic_self_int8_coder.h | 6 +- .../nnacl/int8/batchnorm_int8_coder.cc | 4 +- .../nnacl/int8/batchnorm_int8_coder.h | 2 +- .../opcoders/nnacl/int8/concat_int8_coder.cc | 8 +- .../opcoders/nnacl/int8/concat_int8_coder.h | 2 +- .../nnacl/int8/conv2d_1x1_int8_coder.cc | 12 +- .../nnacl/int8/conv2d_1x1_int8_coder.h | 2 +- 
.../nnacl/int8/conv2d_3x3_int8_coder.cc | 6 +- .../nnacl/int8/conv2d_3x3_int8_coder.h | 2 +- .../opcoders/nnacl/int8/conv2d_int8_coder.cc | 4 +- .../opcoders/nnacl/int8/conv2d_int8_coder.h | 2 +- .../int8/convolution_depthwise_int8_coder.cc | 6 +- .../nnacl/int8/deconvolution_int8_coder.cc | 4 +- .../nnacl/int8/deconvolution_int8_coder.h | 2 +- .../int8/detection_post_process_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/div_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/div_int8_coder.h | 2 +- .../nnacl/int8/fullconnection_int8_coder.cc | 2 +- .../nnacl/int8/fullconnection_int8_coder.h | 4 +- .../opcoders/nnacl/int8/gather_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/gather_int8_coder.h | 6 +- .../nnacl/int8/leaky_relu_int8_coder.cc | 2 +- .../nnacl/int8/leaky_relu_int8_coder.h | 6 +- .../nnacl/int8/matmul_base_int8_coder.cc | 10 +- .../nnacl/int8/matmul_base_int8_coder.h | 2 +- .../opcoders/nnacl/int8/matmul_int8_coder.h | 2 +- .../opcoders/nnacl/int8/pad_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/pad_int8_coder.h | 6 +- .../opcoders/nnacl/int8/pooling_int8_coder.cc | 8 +- .../opcoders/nnacl/int8/pooling_int8_coder.h | 2 +- .../opcoders/nnacl/int8/prelu_int8_coder.h | 2 +- .../opcoders/nnacl/int8/reduce_int8_coder.cc | 6 +- .../opcoders/nnacl/int8/reduce_int8_coder.h | 4 +- .../opcoders/nnacl/int8/relux_int8_coder.cc | 4 +- .../opcoders/nnacl/int8/relux_int8_coder.h | 2 +- .../opcoders/nnacl/int8/reshape_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/resize_int8_coder.cc | 4 +- .../opcoders/nnacl/int8/resize_int8_coder.h | 2 +- .../opcoders/nnacl/int8/sigmoid_int8_coder.cc | 2 +- .../opcoders/nnacl/int8/softmax_int8_coder.cc | 4 +- .../opcoders/nnacl/int8/sub_int8_coder.cc | 3 +- .../opcoders/nnacl/int8/sub_int8_coder.h | 2 +- .../opcoders/nnacl/int8/tanh_int8_coder.cc | 4 +- .../nnacl/int8/transpose_int8_coder.cc | 2 +- .../nnacl/int8/transpose_int8_coder.h | 2 +- .../nnacl_serializer/nnacl_fp32_serializer.h | 64 +- .../nnacl_serializer/nnacl_int8_serializer.h | 34 +- .../nnacl_serializer/nnacl_stream_utils.cc | 10 +- .../nnacl_serializer/nnacl_stream_utils.h | 12 +- .../converter/micro/coder/utils/type_cast.h | 2 +- .../micro/coder/wrapper/base/affine_wrapper.h | 2 +- .../micro/coder/wrapper/base/common_wrapper.h | 2 +- .../coder/wrapper/base/micro_parameter.h | 2 +- .../wrapper/base/optimize_handler_wrapper.h | 2 +- .../wrapper/base/strided_slice_wrapper.c | 2 +- .../wrapper/base/strided_slice_wrapper.h | 2 +- .../wrapper/fp32/activation_fp32_wrapper.c | 4 +- .../wrapper/fp32/activation_fp32_wrapper.h | 2 +- .../wrapper/fp32/arithmetic_fp32_wrapper.h | 2 +- .../coder/wrapper/fp32/concat_fp32_wrapper.c | 4 +- .../coder/wrapper/fp32/conv_fp32_wrapper.c | 4 +- .../coder/wrapper/fp32/conv_fp32_wrapper.h | 2 +- .../wrapper/fp32/conv_winograd_fp32_wrapper.c | 2 +- .../wrapper/fp32/conv_winograd_fp32_wrapper.h | 4 +- .../wrapper/fp32/deconvolution_fp32_wrapper.c | 4 +- .../wrapper/fp32/deconvolution_fp32_wrapper.h | 6 +- .../coder/wrapper/fp32/fill_fp32_wrapper.c | 4 +- .../coder/wrapper/fp32/matmul_fp32_wrapper.c | 2 +- .../coder/wrapper/fp32/matmul_fp32_wrapper.h | 2 +- .../coder/wrapper/fp32/pooling_fp32_wrapper.c | 4 +- .../coder/wrapper/fp32/pooling_fp32_wrapper.h | 2 +- .../coder/wrapper/fp32/scale_fp32_wrapper.c | 4 +- .../coder/wrapper/fp32/scale_fp32_wrapper.h | 2 +- .../coder/wrapper/fp32/slice_fp32_wrapper.c | 2 +- .../coder/wrapper/fp32/slice_fp32_wrapper.h | 4 +- .../coder/wrapper/fp32/split_fp32_wrapper.c | 2 +- .../coder/wrapper/fp32/split_fp32_wrapper.h | 2 +- 
.../wrapper/fp32/transpose_fp32_wrapper.c | 4 +- .../wrapper/fp32/transpose_fp32_wrapper.h | 4 +- .../coder/wrapper/int8/add_int8_wrapper.c | 2 +- .../coder/wrapper/int8/add_int8_wrapper.h | 6 +- .../wrapper/int8/batchnorm_int8_wrapper.c | 4 +- .../wrapper/int8/batchnorm_int8_wrapper.h | 2 +- .../coder/wrapper/int8/concat_int8_wrapper.h | 6 +- .../wrapper/int8/conv1x1_init_int8_wrapper.c | 4 +- .../wrapper/int8/conv1x1_init_int8_wrapper.h | 2 +- .../wrapper/int8/conv1x1_run_int8_wrapper.c | 10 +- .../wrapper/int8/conv1x1_run_int8_wrapper.h | 4 +- .../wrapper/int8/conv3x3_run_int8_wrapper.h | 6 +- .../wrapper/int8/conv_init_int8_wrapper.c | 6 +- .../int8/convolution_depthwise_int8_wrapper.h | 6 +- .../wrapper/int8/convolution_int8_wrapper.h | 8 +- .../coder/wrapper/int8/matmul_int8_wrapper.h | 2 +- .../coder/wrapper/int8/resize_int8_wrapper.c | 2 +- .../coder/wrapper/int8/resize_int8_wrapper.h | 2 +- .../coder/wrapper/int8/slice_int8_wrapper.c | 2 +- .../coder/wrapper/int8/slice_int8_wrapper.h | 4 +- .../wrapper/thread/micro_core_affinity.c | 2 +- .../micro/providers/nnie/nnie_micro.h | 4 +- .../converter/offline_packing_optimizer.cc | 2 +- mindspore-lite/tools/converter/ops/while.cc | 2 +- .../tools/converter/optimizer_manager.cc | 2 +- .../parser/caffe/caffe_activation_parser.cc | 2 +- .../parser/caffe/caffe_argmax_parser.cc | 2 +- .../parser/caffe/caffe_batchnorm_parser.cc | 2 +- .../parser/caffe/caffe_concat_parser.cc | 2 +- .../parser/caffe/caffe_conv_base_parser.cc | 2 +- .../parser/caffe/caffe_convolution_parser.cc | 2 +- .../parser/caffe/caffe_crop_parser.cc | 2 +- .../caffe/caffe_deconvolution_parser.cc | 2 +- .../parser/caffe/caffe_eltwise_parser.cc | 2 +- .../parser/caffe/caffe_exp_parser.cc | 2 +- .../parser/caffe/caffe_flatten_parser.cc | 2 +- .../parser/caffe/caffe_innerproduct_parser.cc | 2 +- .../parser/caffe/caffe_interp_parser.cc | 2 +- .../parser/caffe/caffe_model_parser.cc | 2 +- .../parser/caffe/caffe_node_parser.cc | 2 +- .../parser/caffe/caffe_permute_parser.cc | 2 +- .../parser/caffe/caffe_pooling_parser.cc | 2 +- .../parser/caffe/caffe_power_parser.cc | 2 +- .../parser/caffe/caffe_prelu_parser.cc | 2 +- .../parser/caffe/caffe_quantize_parser.cc | 2 +- .../parser/caffe/caffe_reduce_parser.cc | 2 +- .../parser/caffe/caffe_reshape_parser.cc | 2 +- .../parser/caffe/caffe_scale_parser.cc | 2 +- .../parser/caffe/caffe_slice_parser.cc | 2 +- .../parser/caffe/caffe_softmax_parser.cc | 2 +- .../parser/caffe/caffe_tile_parser.cc | 2 +- .../parser/caffe/caffe_upsample_parser.cc | 2 +- .../converter/parser/conv1d_inout_adjust.cc | 2 +- .../parser/conv2d_transpose_input_adjust.cc | 2 +- .../tools/converter/parser/einsum_adjust.cc | 2 +- .../tools/converter/parser/inputs_adjust.cc | 2 +- .../converter/parser/lstm_adjust_pass.cc | 2 +- .../parser/onnx/onnx_activation_parser.cc | 2 +- .../parser/onnx/onnx_adder_parser.cc | 2 +- .../parser/onnx/onnx_argmax_parser.cc | 2 +- .../parser/onnx/onnx_argmin_parser.cc | 2 +- .../onnx/onnx_arithmetic_operation_parser.cc | 2 +- .../parser/onnx/onnx_batchnorm_parser.cc | 2 +- .../parser/onnx/onnx_biasadd_parser.cc | 2 +- .../converter/parser/onnx/onnx_cast_parser.cc | 2 +- .../converter/parser/onnx/onnx_clip_parser.cc | 2 +- .../parser/onnx/onnx_col2im_parser.cc | 2 +- .../parser/onnx/onnx_concat_parser.cc | 2 +- .../onnx/onnx_constant_of_shape_parser.cc | 2 +- .../parser/onnx/onnx_constant_parser.cc | 2 +- .../parser/onnx/onnx_conv2d_add_parser.cc | 2 +- .../parser/onnx/onnx_conv_base_parser.cc | 2 +- 
.../converter/parser/onnx/onnx_conv_parser.cc | 2 +- .../parser/onnx/onnx_conv_transpose_parser.cc | 2 +- .../parser/onnx/onnx_custom_op_adjust.cc | 2 +- .../parser/onnx/onnx_custom_op_parser.cc | 2 +- .../parser/onnx/onnx_deform_conv2d_parser.cc | 2 +- .../parser/onnx/onnx_depth_to_space_parser.cc | 2 +- .../onnx/onnx_dequantize_linear_parser.cc | 2 +- .../parser/onnx/onnx_dropout_parser.cc | 2 +- .../parser/onnx/onnx_einsum_parser.cc | 2 +- .../converter/parser/onnx/onnx_erf_parser.cc | 2 +- .../parser/onnx/onnx_expand_parser.cc | 2 +- .../onnx/onnx_flash_attention_parser.cc | 2 +- .../parser/onnx/onnx_gather_element_parser.cc | 2 +- .../parser/onnx/onnx_gather_nd_parser.cc | 2 +- .../parser/onnx/onnx_gather_parser.cc | 2 +- .../onnx/onnx_given_tensor_fill_parser.cc | 2 +- .../parser/onnx/onnx_gridsample3d_parser.cc | 2 +- .../parser/onnx/onnx_gridsample_parser.cc | 2 +- .../converter/parser/onnx/onnx_gru_parser.cc | 2 +- .../parser/onnx/onnx_hardswish_parser.cc | 2 +- .../parser/onnx/onnx_identity_parser.cc | 2 +- .../converter/parser/onnx/onnx_if_parser.cc | 2 +- .../parser/onnx/onnx_inputs_adjust.cc | 2 +- .../parser/onnx/onnx_instance_norm_parser.cc | 2 +- .../parser/onnx/onnx_layer_norm_parser.cc | 2 +- .../parser/onnx/onnx_less_or_equal_parser.cc | 2 +- .../parser/onnx/onnx_log_softmax_parser.cc | 2 +- .../converter/parser/onnx/onnx_loop_parser.cc | 2 +- .../parser/onnx/onnx_lp_norm_parser.cc | 2 +- .../converter/parser/onnx/onnx_lrn_parser.cc | 2 +- .../converter/parser/onnx/onnx_lstm_parser.cc | 2 +- .../parser/onnx/onnx_matmul_parser.cc | 2 +- .../parser/onnx/onnx_model_parser.cc | 2 +- .../converter/parser/onnx/onnx_node_parser.cc | 2 +- .../onnx/onnx_non_max_suppression_parser.cc | 2 +- .../parser/onnx/onnx_nonzero_parser.cc | 2 +- .../parser/onnx/onnx_onehot_parser.cc | 2 +- .../converter/parser/onnx/onnx_pad_adjust.cc | 2 +- .../converter/parser/onnx/onnx_pad_parser.cc | 2 +- .../converter/parser/onnx/onnx_pool_parser.cc | 2 +- .../onnx_prompt_flash_attention_parser.cc | 2 +- .../onnx/onnx_quantize_linear_adjust.cc | 2 +- .../onnx/onnx_quantize_linear_parser.cc | 2 +- .../parser/onnx/onnx_quantize_parser.cc | 2 +- .../parser/onnx/onnx_random_normal_parser.cc | 2 +- .../parser/onnx/onnx_range_parser.cc | 2 +- .../parser/onnx/onnx_reduce_parser.cc | 2 +- .../parser/onnx/onnx_reshape_parser.cc | 2 +- .../parser/onnx/onnx_resize_parser.cc | 2 +- .../onnx/onnx_reverse_sequence_parser.cc | 2 +- .../onnx/onnx_scatter_elements_parser.cc | 2 +- .../parser/onnx/onnx_scatter_nd_parser.cc | 2 +- .../parser/onnx/onnx_shape_parser.cc | 2 +- .../parser/onnx/onnx_slice_parser.cc | 2 +- .../parser/onnx/onnx_softmax_parser.cc | 2 +- .../parser/onnx/onnx_space_to_depth_parser.cc | 2 +- .../parser/onnx/onnx_splice_parser.cc | 2 +- .../parser/onnx/onnx_split_parser.cc | 2 +- .../parser/onnx/onnx_squeeze_parser.cc | 2 +- .../converter/parser/onnx/onnx_tile_parser.cc | 2 +- .../converter/parser/onnx/onnx_topk_parser.cc | 2 +- .../parser/onnx/onnx_transpose_parser.cc | 2 +- .../parser/onnx/onnx_trilu_parser.cc | 2 +- .../parser/onnx/onnx_unsqueeze_parser.cc | 2 +- .../parser/onnx/onnx_upsample_parser.cc | 2 +- .../parser/onnx/onnx_where_parser.cc | 2 +- .../tools/converter/parser/parser_utils.cc | 2 +- .../pytorch/pytorch_activation_parser.cc | 2 +- .../parser/pytorch/pytorch_argmax_parser.cc | 2 +- .../pytorch/pytorch_arithmetic_parser.cc | 2 +- .../pytorch/pytorch_batchnorm_parser.cc | 2 +- .../parser/pytorch/pytorch_conv_parser.cc | 2 +- .../parser/pytorch/pytorch_cumsum_parser.cc | 2 +- 
.../pytorch/pytorch_elementop_parser.cc | 2 +- .../pytorch/pytorch_embedding_parser.cc | 2 +- .../parser/pytorch/pytorch_flatten_parser.cc | 2 +- .../parser/pytorch/pytorch_gather_parser.cc | 2 +- .../pytorch/pytorch_list_construct_parser.cc | 2 +- .../pytorch/pytorch_logsoftmax_parser.cc | 2 +- .../parser/pytorch/pytorch_lstm_adjust.cc | 2 +- .../parser/pytorch/pytorch_lstm_parser.cc | 2 +- .../parser/pytorch/pytorch_matmul_parser.cc | 2 +- .../parser/pytorch/pytorch_model_parser.cc | 2 +- .../parser/pytorch/pytorch_node_parser.h | 2 +- .../pytorch_non_max_suppression_parser.cc | 2 +- .../parser/pytorch/pytorch_permute_parser.cc | 2 +- .../parser/pytorch/pytorch_pool_parser.cc | 2 +- .../parser/pytorch/pytorch_pow_parser.cc | 2 +- .../parser/pytorch/pytorch_reshape_parser.cc | 2 +- .../parser/pytorch/pytorch_split_parser.cc | 2 +- .../parser/pytorch/pytorch_to_parser.cc | 2 +- .../parser/pytorch/pytorch_unaryop_parser.cc | 2 +- .../parser/pytorch/pytorch_unstack_parser.cc | 2 +- .../converter/parser/tf/functionalize_cond.cc | 2 +- .../tf/functionalize_control_op_pass.cc | 2 +- .../tf/remove_ineffective_control_flow.cc | 2 +- .../parser/tf/tf_fake_quant_parser.cc | 2 +- .../converter/parser/tf/tf_node_parser.h | 2 +- .../parser/tf/tf_sparse_to_dense_parser.cc | 2 +- .../parser/tf_bidirection_gru_cf_fusion.cc | 2 +- .../parser/tflite/tflite_activation_parser.cc | 2 +- .../parser/tflite/tflite_addn_parser.cc | 2 +- .../parser/tflite/tflite_argmax_parser.cc | 2 +- .../parser/tflite/tflite_argmin_parser.cc | 2 +- .../parser/tflite/tflite_arithmetic_parser.cc | 2 +- .../tflite/tflite_batch_matmul_parser.cc | 2 +- .../tflite/tflite_batch_to_space_parser.cc | 2 +- .../tflite/tflite_broadcast_to_parser.cc | 2 +- .../parser/tflite/tflite_cast_parser.cc | 2 +- .../parser/tflite/tflite_concat_parser.cc | 2 +- .../parser/tflite/tflite_conv_parser.cc | 2 +- .../tflite/tflite_conv_transpose_parser.cc | 2 +- .../parser/tflite/tflite_custom_parser.cc | 2 +- .../tflite/tflite_depth_to_space_parser.cc | 2 +- .../parser/tflite/tflite_dequantize_parser.cc | 2 +- .../tflite/tflite_expand_dims_parser.cc | 2 +- .../parser/tflite/tflite_fill_parser.cc | 2 +- .../tflite/tflite_fullyconnected_parser.cc | 2 +- .../parser/tflite/tflite_gather_nd_parser.cc | 2 +- .../parser/tflite/tflite_gather_parser.cc | 2 +- .../tflite/tflite_hashtable_lookup_parser.cc | 2 +- .../parser/tflite/tflite_if_parser.cc | 2 +- .../parser/tflite/tflite_inputs_adjust.cc | 2 +- .../parser/tflite/tflite_l2norm_parser.cc | 2 +- .../tflite/tflite_log_softmax_parser.cc | 2 +- .../parser/tflite/tflite_logical_parser.cc | 2 +- .../parser/tflite/tflite_lrn_parser.cc | 2 +- .../tflite/tflite_lsh_projection_parser.cc | 2 +- .../parser/tflite/tflite_matmul_parser.cc | 2 +- .../parser/tflite/tflite_model_parser.cc | 2 +- .../parser/tflite/tflite_one_hot_parser.cc | 2 +- .../parser/tflite/tflite_pad_parser.cc | 2 +- .../parser/tflite/tflite_pooling_parser.cc | 2 +- .../parser/tflite/tflite_quantize_parser.cc | 2 +- .../parser/tflite/tflite_range_parser.cc | 2 +- .../parser/tflite/tflite_rank_parser.cc | 2 +- .../parser/tflite/tflite_reduce_parser.cc | 2 +- .../parser/tflite/tflite_reshape_parser.cc | 2 +- .../parser/tflite/tflite_resize_parser.cc | 2 +- .../parser/tflite/tflite_reverse_parser.cc | 2 +- .../tflite/tflite_reverse_sequence_parser.cc | 2 +- .../parser/tflite/tflite_scatter_nd_parser.cc | 2 +- .../parser/tflite/tflite_shape_parser.cc | 2 +- .../parser/tflite/tflite_skip_gram_parser.cc | 2 +- .../parser/tflite/tflite_slice_parser.cc | 2 +- 
 .../parser/tflite/tflite_softmax_parser.cc | 2 +-
 .../tflite/tflite_space_to_batch_nd_parser.cc | 2 +-
 .../tflite/tflite_space_to_depth_parser.cc | 2 +-
 .../tflite/tflite_sparse_to_dense_parser.cc | 2 +-
 .../parser/tflite/tflite_split_parser.cc | 2 +-
 .../parser/tflite/tflite_split_v_parser.cc | 2 +-
 .../parser/tflite/tflite_squeeze_parser.cc | 2 +-
 .../parser/tflite/tflite_stack_parser.cc | 2 +-
 .../tflite/tflite_strided_slice_parser.cc | 2 +-
 .../parser/tflite/tflite_tile_parser.cc | 2 +-
 .../parser/tflite/tflite_topk_v2_parser.cc | 2 +-
 .../parser/tflite/tflite_transpose_parser.cc | 2 +-
 .../parser/tflite/tflite_unique_parser.cc | 2 +-
 .../parser/tflite/tflite_unstack_parser.cc | 2 +-
 .../converter/parser/tflite/tflite_util.cc | 2 +-
 .../parser/tflite/tflite_where_parser.cc | 2 +-
 .../parser/tflite/tflite_while_parser.cc | 2 +-
 .../parser/tflite/tflite_zeros_like_parser.cc | 2 +-
 .../tools/converter/parser/unify_format.cc | 2 +-
 .../tools/converter/quantizer/cle_pattern.cc | 2 +-
 .../converter/quantizer/debug_info_manager.h | 2 +-
 .../tools/converter/quantizer/fse_decoder.cc | 2 +-
 .../tools/converter/quantizer/fse_encoder.cc | 2 +-
 .../quantizer/full_quant_quantizer.cc | 2 +-
 .../tools/converter/quantizer/gptq.h | 2 +-
 .../converter/quantizer/gptq_quantizer.h | 2 +-
 .../ascend_distribute_fake_quant_transform.cc | 2 +-
 .../quantizer/quant_helper/ffn_full_quant.cc | 2 +-
 .../converter/quantizer/quant_param_holder.h | 2 +-
 .../converter/quantizer/quant_strategy.cc | 2 +-
 .../converter/quantizer/quantize_util.cc | 2 +-
 .../tools/converter/quantizer/smooth_quant.cc | 4 +-
 .../converter/quantizer/split_shared_bias.cc | 2 +-
 .../tools/converter/registry/CMakeLists.txt | 2 +-
 .../registry/model_parser_registry.cc | 2 +-
 .../tools/converter/registry/pass_registry.cc | 2 +-
 .../tools/cropper/build_cropper_config.sh | 70 +-
 .../tools/graph_kernel/common/infer_shape.cc | 6 +-
 .../tools/graph_kernel/common/utils.h | 2 +-
 .../tools/graph_kernel/runtime/akg_kernel.h | 2 +-
 .../tools/lite_exporter/anf_exporter.cc | 2 +-
 .../tools/lite_exporter/fetch_content.cc | 2 +-
 .../tools/lite_exporter/fetch_content.h | 2 +-
 .../tools/optimizer/common/format_utils.cc | 2 +-
 .../tools/optimizer/common/gllo_utils.cc | 2 +-
 .../tools/optimizer/common/helper.cc | 2 +-
 .../common/multiple_pattern_process_pass.cc | 2 +-
 .../const_fold/fold_along_infershape.cc | 2 +-
 .../const_fold/fold_with_infershape.cc | 2 +-
 .../fisson/eliminate_concat_split.cc | 2 +-
 .../tools/optimizer/fisson/fisson_util.cc | 2 +-
 .../optimizer/fisson/iter_node_outputs.cc | 2 +-
 .../optimizer/fisson/multi_conv_split_pass.cc | 2 +-
 .../tools/optimizer/fisson/node_out_shapes.cc | 2 +-
 .../format/delete_redundant_transpose.cc | 2 +-
 .../tools/optimizer/format/to_format_base.cc | 2 +-
 .../optimizer/fusion/activation_fusion.cc | 12 +-
 .../optimizer/fusion/add_activation_fusion.cc | 2 +-
 .../fusion/add_concat_activation_fusion.cc | 2 +-
 .../optimizer/fusion/add_layernorm_fusion.cc | 2 +-
 .../optimizer/fusion/adjust_col2im_pass.cc | 2 +-
 .../fusion/affine_activation_fusion.cc | 2 +-
 .../tools/optimizer/fusion/affine_fusion.cc | 2 +-
 ...tiquant_add_mul_matmul_allreduce_fusion.cc | 2 +-
 .../optimizer/fusion/batchmatmul_fusion.cc | 2 +-
 .../fusion/batchnorm_to_scale_fusion.cc | 2 +-
 .../tools/optimizer/fusion/cast_fusion.cc | 2 +-
 .../fusion/conv_activation_fusion.cc | 2 +-
 .../optimizer/fusion/conv_biasadd_fusion.cc | 2 +-
 .../tools/optimizer/fusion/conv_bn_fusion.cc | 2 +-
 .../optimizer/fusion/conv_conv_fusion.cc | 2 +-
 .../tools/optimizer/fusion/conv_pad_fusion.cc | 2 +-
 .../optimizer/fusion/conv_scale_fusion.cc | 2 +-
 .../optimizer/fusion/conv_transform_fusion.cc | 2 +-
 .../fusion/conv_tuple_activation_fusion.cc | 2 +-
 .../fusion/conv_tuplegetitem_fusion.cc | 2 +-
 .../optimizer/fusion/decoder_layer_fusion.cc | 2 +-
 .../optimizer/fusion/encoder_layer_fusion.cc | 5 +-
 .../fusion/expanddims_reshape_fusion.cc | 2 +-
 .../tools/optimizer/fusion/ffn_custom_pass.cc | 2 +-
 .../tools/optimizer/fusion/ffn_fusion.cc | 2 +-
 .../flash_attention_fusion_for_custom.cc | 2 +-
 .../fusion/flash_attention_tik_fusion.cc | 2 +-
 .../fusion/fullconnected_add_fusion.cc | 2 +-
 .../optimizer/fusion/fullconnected_fusion.cc | 2 +-
 .../tools/optimizer/fusion/gelu_fusion.cc | 2 +-
 .../tools/optimizer/fusion/glu_fusion.cc | 2 +-
 .../optimizer/fusion/groupnorm_fusion.cc | 2 +-
 .../optimizer/fusion/hard_swish_fusion.cc | 2 +-
 .../fusion/kv_cache_mgr_assign_fusion.cc | 2 +-
 .../fusion/kv_cache_mgr_concat_fusion.cc | 2 +-
 .../fusion/kv_cache_mgr_load_fusion.cc | 2 +-
 .../fusion/kv_cache_mgr_one_branch_fusion.cc | 2 +-
 .../optimizer/fusion/leaky_relu_fusion.cc | 2 +-
 .../fusion/matmul_activation_fusion.cc | 2 +-
 .../optimizer/fusion/matmul_add_fusion.cc | 2 +-
 .../fusion/matmul_allreduce_fusion.cc | 2 +-
 .../optimizer/fusion/matmul_mul_fusion.cc | 2 +-
 .../optimizer/fusion/matmul_scale_fusion.cc | 2 +-
 .../optimizer/fusion/mul_activation_fusion.cc | 2 +-
 .../tools/optimizer/fusion/mul_add_fusion.cc | 2 +-
 .../optimizer/fusion/mul_reduce_fusion.cc | 2 +-
 .../fusion/multi_head_attention_fusion.cc | 2 +-
 .../tools/optimizer/fusion/norm_fusion.cc | 2 +-
 .../optimizer/fusion/onnx_gelu_fusion.cc | 2 +-
 .../tools/optimizer/fusion/prelu_fusion.cc | 2 +-
 .../fusion/quant_dtype_cast_fusion.cc | 2 +-
 .../optimizer/fusion/reduce_stack_fusion.cc | 2 +-
 .../fusion/remove_transitivity_op.cc | 2 +-
 .../fusion/reshape_like_operator_ablation.cc | 2 +-
 .../optimizer/fusion/reshape_reduce_fusion.cc | 2 +-
 .../fusion/reshape_reshape_fusion.cc | 2 +-
 .../optimizer/fusion/reshape_shape_fusion.cc | 2 +-
 .../fusion/reshape_transpose_fusion.cc | 2 +-
 .../tools/optimizer/fusion/resize_fusion.cc | 2 +-
 .../fusion/scale_activation_fusion.cc | 2 +-
 .../optimizer/fusion/scale_base_fusion.cc | 2 +-
 .../optimizer/fusion/scale_scale_fusion.cc | 2 +-
 .../optimizer/fusion/sigmoid_mul_fusion.cc | 2 +-
 .../fusion/squeeze_expanddims_fusion.cc | 2 +-
 .../tools/optimizer/fusion/squeeze_fusion.cc | 2 +-
 .../optimizer/fusion/strided_slice_fusion.cc | 2 +-
 .../optimizer/fusion/tensor_dot_fusion.cc | 2 +-
 .../fusion/tf_bidirection_gru_fusion.cc | 2 +-
 .../tools/optimizer/fusion/tf_gelu_fusion.cc | 2 +-
 .../optimizer/fusion/tf_lstm_cell_fusion.cc | 2 +-
 .../fusion/tflite_lstm_cell_fusion.cc | 2 +-
 ...ite_rel_pos_multi_head_attention_fusion.cc | 2 +-
 .../optimizer/fusion/tile_matmul_fusion.cc | 2 +-
 .../optimizer/fusion/transpose_fusion.cc | 2 +-
 .../fusion/transpose_gather_fusion.cc | 2 +-
 .../fusion/transpose_matmul_fusion.cc | 2 +-
 .../tools/optimizer/graph/add_tensor_array.cc | 2 +-
 .../optimizer/graph/attr_to_args_pass.cc | 2 +-
 .../optimizer/graph/broadcast_for_select.cc | 7 +-
 .../graph/clip_convert_activation_pass.cc | 2 +-
 .../optimizer/graph/control_flow_pass.cc | 2 +-
 .../optimizer/graph/core_infershape_pass.cc | 2 +-
 .../graph/decrease_transpose_algo.cc | 2 +-
 .../graph/group_depthwise_op_convert_pass.cc | 2 +-
 .../tools/optimizer/graph/infershape_pass.cc | 2 +-
 .../graph/input_data_type_trans_pass.cc | 2 +-
 .../optimizer/graph/int64_cast_int32_pass.cc | 2 +-
 .../optimizer/graph/lite_tensor_extractor.cc | 2 +-
 .../optimizer/graph/miniaturization_pass.cc | 2 +-
 .../optimizer/graph/mul_constant_pass.cc | 2 +-
 .../tools/optimizer/graph/node_infershape.cc | 2 +-
 .../graph/preprocess_dynamic_shape.cc | 2 +-
 .../graph/redundant_op_remove_pass.cc | 2 +-
 .../optimizer/graph/slice_prepose_pass.cc | 2 +-
 .../graph/special_node_postprocess.cc | 2 +-
 .../graph/specify_graph_input_format.cc | 2 +-
 .../graph/specify_graph_output_format.cc | 2 +-
 .../optimizer/graph/transpose_strategy.cc | 2 +-
 .../unused_transpose_node_remove_pass.cc | 2 +-
 .../tools/optimizer/parallel/conv2d_info.cc | 2 +-
 .../parallel/depthwise_conv2d_info.cc | 2 +-
 .../optimizer/parallel/multi_conv_info.cc | 2 +-
 .../optimizer/parallel/multi_node_split.cc | 2 +-
 .../tools/optimizer/parallel/operator_info.cc | 2 +-
 .../tools/optimizer/parallel/parallel_pass.cc | 2 +-
 .../optimizer/parallel/split_strategy.cc | 2 +-
 3053 files changed, 236400 insertions(+), 2253 deletions(-)
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/CMakeLists.txt
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/OWNERS
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/README.md
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/activation_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/affine_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/all_gather_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/arg_min_max_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_self_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDw3x3Int8BorderPixel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Border.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Row.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Row.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwFp32Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Post.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt16to32_8x4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt8_2x4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatVecMulFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt12x4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8Opt.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulWinogradFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Peroc.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Pert.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/TiledC4MatmulFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransLeft.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransRight.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/AdderFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/BigMatmulFp32Opt.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Corner.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Horizontal.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride1.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride2.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Vertical.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Corner.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Horizontal.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Stride2.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Vertical.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Line.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Border.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect3x3.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect5x5.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Row.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Row.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvFp32Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x16Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x8Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x16Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x8Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x16Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x8Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x16Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x8Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x16Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x8Kernel.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Border.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Post.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DynamicGatherArm64.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/IndirectGemmInt16to32_8x4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulPackFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32Opt.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow12.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8Opt.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulWinogradFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC4.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncInt8C4Neon64.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Peroc.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Pert.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/TiledC4MatmulFp32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransLeft.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransRight.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float16Tofloat32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float32ToFloat16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/MatVecMulFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransLeft.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransRight.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32BorderAvx.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowAvx.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowOptAVX.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/MatmulAvx.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx512/ConvDwFp32RowAVX512.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/CalculateMinMaxFp16Count8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Border.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Border.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Center.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DynamicGatherArm64ForFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float16ToFloat32.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Matmul12X16Fp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulBaseFp16Neon.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16OptV2.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulWinogradFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC4Fp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC8Fp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/TiledC4MatmulFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransLeftFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransRightFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8Opt.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulOptR4Int8.S
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly_global.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/attention_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/sequence_unstack_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/batch_to_space_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/batchnorm_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/broadcast_to_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/call_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/clip_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/concat_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/constant_of_shape_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv3d_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/crop_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/cumsum_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_gru_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_is_inf_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_masked_fill_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/depth_to_space_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/detection_post_process_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/dynamic_quant_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/exp_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generate_hpc.sh
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generator.py
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_mask_nhwc_asm.c.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fill_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/flatten_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/format_transpose_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/cast_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/constant_of_shape_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/range_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16_macro.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_loigts_loss_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/constant_of_shape_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prior_box_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/range_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rank_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32_simd.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.c
 create mode 100644
mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32_simd.h.in create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32_simd.h.in create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32_simd.h.in create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.h create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_simd.h.in create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.h create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernormgrad_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/optimizer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/smooth_l1_loss.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_crossentropy_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/utils.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_nd_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/gelu_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/glu_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/grid_sampler_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/group_norm_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/gru_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.h 
create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.c create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.h create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.h create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.c create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.c create 
mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.c create 
mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.c create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.c create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/instance_norm_parameter.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.c create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.c create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.h create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.c create mode 100644 
mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/DeconvMatMulAvx.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx512_instructions.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx_instructions.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_neon_instructions.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_sse_instructions.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32IndirectRow.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32Row_sse.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/DepthwiseFp32_Sse.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/sse_common.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_struct.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/l2_norm_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/layer_norm_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/local_response_norm_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/lsh_projection_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/lstm_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/matmul_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/mul_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/nllloss_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/non_max_suppression_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/one_hot_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_base.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_simd_header_file.h.in
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/optimize/CMakeLists.txt
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/pack.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/pad_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/partial_fusion_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/pooling_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/pow_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/predict_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/prelu_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/prior_box_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/random_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/range_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_scatter_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/reshape_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/resize_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_sequence_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/scale_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_elements_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_nd_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/sequence_unstack_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/sigmoid_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/skip_gram_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/slice_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/softmax_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/space_to_depth_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/sparse_to_dense_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/splice_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/split_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/squeeze_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/stack_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/strided_slice_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_array_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/tile_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/transpose_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/triu_tril_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/unsqueeze_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/unstack_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/upsample_parameter.h
 create mode 100644 mindspore-lite/src/litert/kernel/cpu/nnacl_c/where_parameter.h
diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 4f562fc7..10e853ce 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -36,3 +36,7 @@
 "mindspore-lite/mindspore-lite/examples/quick_start_micro/" "syntaxError"
 "mindspore-lite/mindspore-lite/python/src/pybind_module.cc" "syntaxError"
 "mindspore-lite/mindspore-lite/java/src/main/native/model.cpp" "unreadVariable"
+
+# nnacl
+"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/" "unreadVariable"
+"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c" "unknownMacro"
diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt
index 7b1dc36a..60814c36 100644
--- a/.jenkins/check/config/filter_cpplint.txt
+++ b/.jenkins/check/config/filter_cpplint.txt
@@ -91,3 +91,6 @@
 "mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "legal/copyright"
 "mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "whitespace/ending_newline"
 "mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "build/include"
+
+# nnacl
+"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/" "readability/casting"
diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index c76e4a3d..b494fdf5 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -36,3 +36,208 @@ mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_vid
 # other
 mindspore-lite/mindspore-lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
 mindspore-lite/mindspore-lite/providers/nnie/src/custom_infer.cc:mindspore::nnie::CustomInterface::Infer
+
+# nnacl
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c:StridedSliceInferShape
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c:CheckInputShapeValid
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c:WinogradInputTransformFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:AvgPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:MaxPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2ReluUnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6ReluUnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c:AvgPoolingOptInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c:MaxPoolingWithQuantInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c:Conv1x1PreOptPeroc
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c:RegisterInfer
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c:RowMajor2Col12MajorStride
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c:RowMajor2Col8MajorStride
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c:Conv3x3Fp16InputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c:Conv3x3Fp16FilterTransform
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:AvgPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c:MaxPoolingFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c:PackNHWCToNCHWFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:InputTransform6x6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:InputTransform8x8UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform4x2Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6ReluUnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c:OutputTransform8x6Relu6UnitFp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c:AvgPoolingOptInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8InputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8FilterTransform
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c:Conv1x1PreOptPeroc
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c:Conv1x1PreOptPert
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c:PackNHWCToNCHWInt8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c:AvgPooling
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c:MatMul4x1Kernel, MatMul2x1Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c:SWConv3x32AVXKernel, SWConv4x24AVXKernel, SWConv12x8AVXKernel, SWConv8x8AVXKernel, SWConv4x8AVXKernel, SWConv6x16AVXKernel, SWConv4x16AVXKernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c:Conv1x1SW3x32AVXKernel, Conv1x1SW4x24AVXKernel, Conv1x1SW12x8AVXKernel, Conv1x1SW8x8AVXKernel, Conv1x1SW4x8AVXKernel, Conv1x1SW6x16AVXKernel, Conv1x1SW4x16AVXKernel, Conv1x1SW1x32AVXKernel, Conv1x1SW1x24AVXKernel, Conv1x1SW1x16AVXKernel, Conv1x1SW1x8AVXKernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c:TiledC4MatmulFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC4.c:PostFuncBiasReluC4
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c:WinogradTransRight
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c:WinogradTransLeft
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c:WinogradPostFuncBiasReluC4
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c:PostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c:WinogradPostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c:PackDeConvWgDataFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c:WinogradPostFuncBiasReluC4
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c:PostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c:WinogradTransLeft
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c:WinogradTransRight
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c:PostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c:WinogradPostFuncBiasReluC8
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c:PackDeConvWgDataFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c:DeConvWgMerge
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x16_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x48_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c:MatmulFloatSse64Opt
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c:ConvWinogardFp32
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c:ConvWinogardFp32CutByBatch
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512
+mindspore-lite/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c:GemmRowxColMaskKernelFp32
\ No newline at end of file
diff --git a/mindspore-lite/CMakeLists.txt b/mindspore-lite/CMakeLists.txt
index c497ee12..8acc712e 100644
--- a/mindspore-lite/CMakeLists.txt
+++ b/mindspore-lite/CMakeLists.txt
@@ -743,12 +743,12 @@ set(CORE_DIR ${TOP_DIR}/mindspore/mindspore/core)
 set(CORE_INC_DIR ${TOP_DIR}/mindspore/mindspore/core/include)
 set(CCSRC_DIR ${TOP_DIR}/mindspore/mindspore/ccsrc)
${TOP_DIR}/mindspore/mindspore/ccsrc) set(OPS_DIR ${TOP_DIR}/mindspore/mindspore/ops) -set(NNACL_DIR ${OPS_DIR}/kernel/cpu/nnacl) +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/litert/kernel/cpu/nnacl_c) if(PLATFORM_MCU) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-incompatible-pointer-types") # set(MSLITE_DEPS_CMSIS on) - add_subdirectory(${NNACL_DIR} build/nnacl) + add_subdirectory(${NNACL_DIR} build/nnacl_c) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/cmake/cortex-m/ build) include(${TOP_DIR}/cmake/package_lite.cmake) return() @@ -1062,7 +1062,7 @@ if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR TARGET_OHOS_LITE OR TARGET_HIMIX) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src) -add_subdirectory(${OPS_DIR}/kernel/cpu/nnacl build) +add_subdirectory(${NNACL_DIR} build) if(MSLITE_ENABLE_TOOLS) if(NOT MSLITE_COMPILE_TWICE) diff --git a/mindspore-lite/java/native/CMakeLists.txt b/mindspore-lite/java/native/CMakeLists.txt index ef9368c1..3cac3642 100644 --- a/mindspore-lite/java/native/CMakeLists.txt +++ b/mindspore-lite/java/native/CMakeLists.txt @@ -10,6 +10,7 @@ set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../..) set(MINDSPORE_DIR ${TOP_DIR}/mindspore) set(LITE_DIR ${TOP_DIR}/mindspore-lite) set(NEW_NATIVE_DIR ${LITE_DIR}/java/src/main/native) +set(NNACL_DIR ${LITE_DIR}/src/litert/kernel/cpu/nnacl_c) include(${LITE_DIR}/cmake/secure_option.cmake) include(${LITE_DIR}/cmake/compile_link_option.cmake) @@ -110,7 +111,7 @@ include_directories(${MINDSPORE_DIR}) ## api include include_directories(${MINDSPORE_DIR}/mindspore/core/include) ## core include include_directories(${MINDSPORE_DIR}/mindspore/core/mindrt) ## core include include_directories(${MINDSPORE_DIR}/mindspore/core/mindrt/include) ## core include -include_directories(${MINDSPORE_DIR}/mindspore/ops/kernel/cpu) +include_directories(${NNACL_DIR}/../) include_directories(${TOP_DIR}/build) ## flatbuffers if(PLATFORM_ARM64 OR PLATFORM_ARM32) @@ -137,7 +138,7 @@ set(JNI_SRC ) set(CCSRC - ${MINDSPORE_DIR}/mindspore/ops/kernel/cpu/nnacl/nnacl_common.c + ${NNACL_DIR}/nnacl_common.c ) if(MSLITE_ENABLE_PARALLEL_INFERENCE) diff --git a/mindspore-lite/java/native/common/jni_utils.h b/mindspore-lite/java/native/common/jni_utils.h index 3980a4d0..1f57a43a 100644 --- a/mindspore-lite/java/native/common/jni_utils.h +++ b/mindspore-lite/java/native/common/jni_utils.h @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" std::string RealPath(const char *path); diff --git a/mindspore-lite/minddata/CMakeLists.txt b/mindspore-lite/minddata/CMakeLists.txt index 4557fecc..a7b68b51 100644 --- a/mindspore-lite/minddata/CMakeLists.txt +++ b/mindspore-lite/minddata/CMakeLists.txt @@ -94,7 +94,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") include_directories("dataset/liteapi") include_directories("${TOP_DIR}/mindspore-lite") include_directories("${TOP_DIR}") - include_directories("${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu") + include_directories("${NNACL_DIR}/../") if(MSLITE_ENABLE_ACL) include_directories(${CCSRC_DIR}) @@ -105,7 +105,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") ${TOP_DIR}/mindspore-lite/src/litert/cxx_api/tensor_utils.cc ${TOP_DIR}/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.cc ${TOP_DIR}/mindspore-lite/src/tensor.cc - ${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c + ${NNACL_DIR}/tensor_c_utils.c ${TOP_DIR}/mindspore-lite/src/common/utils.cc ${TOP_DIR}/mindspore-lite/src/common/string_util.cc) diff --git 
diff --git a/mindspore-lite/python/CMakeLists.txt b/mindspore-lite/python/CMakeLists.txt
index 14747406..238d0950 100644
--- a/mindspore-lite/python/CMakeLists.txt
+++ b/mindspore-lite/python/CMakeLists.txt
@@ -17,7 +17,7 @@ if(Python3_FOUND)
     include_directories(${TOP_DIR}/mindspore/mindspore/core/include)
     include_directories(${TOP_DIR}/mindspore/mindspore/core/mindrt)
     include_directories(${TOP_DIR}/mindspore/mindspore/core/mindrt/include)
-    include_directories(${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu/)
+    include_directories(${NNACL_DIR}/../)
     if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
         add_compile_definitions(MSLITE_ENABLE_CLOUD_INFERENCE)
diff --git a/mindspore-lite/src/CMakeLists.txt b/mindspore-lite/src/CMakeLists.txt
index 52019938..0d2a5f57 100644
--- a/mindspore-lite/src/CMakeLists.txt
+++ b/mindspore-lite/src/CMakeLists.txt
@@ -2,7 +2,7 @@ add_compile_definitions(USE_ANDROID_LOG)
 set(LITE_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CORE_INC_DIR})
-include_directories(${OPS_DIR}/kernel/cpu)
+include_directories(${NNACL_DIR}/../)
 include_directories(${OPS_DIR}/kernel/include)
 set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../tools)
diff --git a/mindspore-lite/src/common/common.h b/mindspore-lite/src/common/common.h
index 0de6cb8f..83ea8e7a 100644
--- a/mindspore-lite/src/common/common.h
+++ b/mindspore-lite/src/common/common.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_SRC_COMMON_COMMON_H_
 #include
-#include "mindspore/ops/kernel/cpu/nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 /* Naming a key of path must be consistent with existing naming styles and follow the following rules:
diff --git a/mindspore-lite/src/common/graph_util.cc b/mindspore-lite/src/common/graph_util.cc
index d5cb2a98..c0253913 100644
--- a/mindspore-lite/src/common/graph_util.cc
+++ b/mindspore-lite/src/common/graph_util.cc
@@ -23,7 +23,7 @@
 #include "src/common/log_adapter.h"
 #include "src/common/version_manager.h"
 #include "include/errorcode.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/src/common/ops/CMakeLists.txt b/mindspore-lite/src/common/ops/CMakeLists.txt
index 0c6a666f..a6cfa820 100644
--- a/mindspore-lite/src/common/ops/CMakeLists.txt
+++ b/mindspore-lite/src/common/ops/CMakeLists.txt
@@ -1,5 +1,5 @@
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
-include_directories(${OPS_DIR}/kernel/cpu)
+include_directories(${NNACL_DIR}/../)
 if(APPLE)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstrict-aliasing -ffunction-sections \
         -fdata-sections -ffast-math -fno-rtti -fno-exceptions -Wno-shorten-64-to-32 \
diff --git a/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc
index 0299e0a2..e4681cba 100644
--- a/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/activation_grad_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
 #include "infer/grad/activation_grad.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameActivationGrad;
diff --git a/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc b/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc
index e2294c51..587a7753 100644
--- a/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/activation_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/activation.h"
 #include "infer/leaky_relu.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc b/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc
index ec1936f7..12c525e5 100644
--- a/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/adder_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "infer/adder.h"
 #include "infer/cxx_api/adder_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc b/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc
index 4056c571..45ab33ed 100644
--- a/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/affine_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/affine_parameter.h"
+#include "nnacl_c/affine_parameter.h"
 #include "infer/affine.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAffine;
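
The hunks from here on repeat one mechanical rewrite across the operator-populate sources: every include that starts with nnacl/ becomes nnacl_c/, and nothing else in each file changes. A small C++ sketch of that rule, for illustration only (no such script ships with this patch):

#include <iostream>
#include <string>

// Rewrites one include path the way this patch does: a leading "nnacl/"
// prefix becomes "nnacl_c/"; any other path passes through unchanged.
std::string RewriteNnaclInclude(const std::string &path) {
  const std::string kOld = "nnacl/";
  if (path.compare(0, kOld.size(), kOld) == 0) {
    return "nnacl_c/" + path.substr(kOld.size());
  }
  return path;
}

int main() {
  std::cout << RewriteNnaclInclude("nnacl/conv_parameter.h") << "\n";  // nnacl_c/conv_parameter.h
  std::cout << RewriteNnaclInclude("infer/adder.h") << "\n";           // unchanged
  return 0;
}
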
diff --git a/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc b/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc
index 149d5983..fd7c3450 100644
--- a/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/all_gather_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/all_gather_parameter.h"
+#include "nnacl_c/all_gather_parameter.h"
 #include "infer/all_gather.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAllGather;
diff --git a/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc b/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc
index acbb5363..78a503ac 100644
--- a/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/arg_minmax_populate.cc
@@ -16,7 +16,7 @@
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
 #include "mindspore/ops/op_def/array_ops.h"
-#include "nnacl/arg_min_max_parameter.h"
+#include "nnacl_c/arg_min_max_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/arg_max_fusion.h"
 #include "infer/cxx_api/arg_min_fusion.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h b/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h
index 9f514dae..c14e2f94 100644
--- a/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h
+++ b/mindspore-lite/src/common/ops/operator_populate/arithmetic_operator_populate.h
@@ -16,7 +16,7 @@
 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_ARITHMETIC_POPULATE_H_
 #define MINDSPORE_LITE_SRC_COMMON_OPS_OPERATOR_POPULATE_ARITHMETIC_POPULATE_H_
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc b/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc
index 35a4a219..1c922365 100644
--- a/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/arithmetic_self_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/arithmetic_self_parameter.h"
+#include "nnacl_c/arithmetic_self_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/grad/log_grad.h"
 #include "infer/grad/neg_grad.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc b/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc
index 128cb9c5..9b9470cf 100644
--- a/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/attention_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/attention_parameter.h"
+#include "nnacl_c/attention_parameter.h"
 #include "infer/attention.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAttention;
diff --git a/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc b/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc
index a9930018..2aa371f2 100644
--- a/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/audio_spectrogram_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/infer/audio_spectrogram_infer.h"
+#include "nnacl_c/infer/audio_spectrogram_infer.h"
 #include "infer/audio_spectrogram.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 using mindspore::ops::kNameAudioSpectrogram;
diff --git a/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc
index a42ac7f5..406854a3 100644
--- a/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/base_operator_populate.cc
@@ -14,16 +14,16 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/where_parameter.h"
-#include "nnacl/sparse_to_dense_parameter.h"
-#include "nnacl/transpose_parameter.h"
-#include "nnacl/triu_tril_parameter.h"
-#include "nnacl/fp32/unique_fp32.h"
-#include "nnacl/scatter_nd_parameter.h"
-#include "nnacl/op_base.h"
-#include "nnacl/gather_parameter.h"
-#include "nnacl/gather_nd_parameter.h"
-#include "nnacl/reshape_parameter.h"
+#include "nnacl_c/where_parameter.h"
+#include "nnacl_c/sparse_to_dense_parameter.h"
+#include "nnacl_c/transpose_parameter.h"
+#include "nnacl_c/triu_tril_parameter.h"
+#include "nnacl_c/fp32/unique_fp32.h"
+#include "nnacl_c/scatter_nd_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/gather_nd_parameter.h"
+#include "nnacl_c/reshape_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/adam.h"
 #include "infer/assert.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc
index 7b0248fd..a89aea9f 100644
--- a/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/batch_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h"
 using mindspore::schema::PrimitiveType_BatchNorm;
diff --git a/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc b/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc
index e5385ee5..ab8e965b 100644
--- a/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/batch_to_space_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/batch_to_space_parameter.h"
+#include "nnacl_c/batch_to_space_parameter.h"
 #include "infer/batch_to_space.h"
 #include "infer/batch_to_space_nd.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc b/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc
index 8afe3ef4..d4a97f54 100644
--- a/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/broadcast_to_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/base/broadcast_to.h"
+#include "nnacl_c/base/broadcast_to.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/call_populate.cc b/mindspore-lite/src/common/ops/operator_populate/call_populate.cc
index 785c7c55..a727ba02 100644
--- a/mindspore-lite/src/common/ops/operator_populate/call_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/call_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/call_parameter.h"
+#include "nnacl_c/call_parameter.h"
 #include "infer/call.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameCall;
diff --git a/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc b/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc
index 22eee0fd..491416fd 100644
--- a/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/clip_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/clip_parameter.h"
+#include "nnacl_c/clip_parameter.h"
 #include "infer/clip.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameClip;
diff --git a/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc b/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc
index 76ff105a..e10d7e73 100644
--- a/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/concat_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameConcat;
diff --git a/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc b/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
index 499ea4f8..1fefda84 100644
--- a/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/constant_of_shape_parameter.h"
+#include "nnacl_c/constant_of_shape_parameter.h"
 #include "infer/constant_of_shape.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameConstantOfShape;
diff --git a/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc b/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc
index d2762e87..f326917f 100644
--- a/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/conv2d_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "infer/conv2d.h"
 #include "infer/cxx_api/conv2d_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc b/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc
index 801c8c2b..99d4dff4 100644
--- a/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/crop_and_resize_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/resize_parameter.h"
+#include "nnacl_c/resize_parameter.h"
 #include "infer/crop_and_resize.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameCropAndResize;
diff --git a/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc b/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc
index ad24e8e3..d52c98d9 100644
--- a/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/crop_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/crop_parameter.h"
+#include "nnacl_c/crop_parameter.h"
 #include "infer/crop.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameCrop;
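
For context, each *_populate.cc unit above builds one of the plain-C parameter structs declared in these nnacl_c headers. The C++ sketch below shows the usual allocate-zero-fill shape of such a routine; ConcatParameter and its axis_ field follow the upstream nnacl declaration and should be read as assumptions, since this patch only renames the include prefix.

#include <cstdlib>
#include <cstring>
#include "nnacl_c/concat_parameter.h"  // prefix as renamed by this patch

// Illustrative populate-style helper: allocate a C parameter struct on the
// heap, zero it, and fill the op-specific field (assumed name: axis_).
ConcatParameter *CreateConcatParameter(int axis) {
  auto *param = static_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
  if (param == nullptr) {
    return nullptr;
  }
  memset(param, 0, sizeof(ConcatParameter));
  param->axis_ = axis;
  return param;
}
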
diff --git a/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc b/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc
index 83ceca36..3bbb8c45 100644
--- a/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/cumsum_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/cumsum_parameter.h"
+#include "nnacl_c/cumsum_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameCumSum;
diff --git a/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc b/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc
index 54f21285..80e18750 100644
--- a/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/custom_populate.cc
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/custom_parameter.h"
-#include "nnacl/split_parameter.h"
+#include "nnacl_c/custom_parameter.h"
+#include "nnacl_c/split_parameter.h"
 #include "infer/custom.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameCustom;
diff --git a/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc b/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc
index b17e6075..46207a60 100644
--- a/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/custom_predict_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/predict_parameter.h"
+#include "nnacl_c/predict_parameter.h"
 #include "infer/custom_predict.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameCustomPredict;
diff --git a/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc b/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc
index 26a66012..494f7e68 100644
--- a/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/deconv2d_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "infer/cxx_api/conv2d_transpose_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 using mindspore::ops::kNameConv2dTransposeFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc b/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc
index 155038b0..28936315 100644
--- a/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/depth_to_space_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/depth_to_space_parameter.h"
+#include "nnacl_c/depth_to_space_parameter.h"
 #include "infer/depth_to_space.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h"
 using mindspore::ops::kNameDepthToSpace;
diff --git a/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc b/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc
index 80d8e9fc..7a10ce92 100644
--- a/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/detection_post_process_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/detection_post_process_parameter.h"
+#include "nnacl_c/detection_post_process_parameter.h"
 #include "infer/detection_post_process.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h"
 using mindspore::ops::kNameDetectionPostProcess;
diff --git a/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc b/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc
index 6da61ae0..cbd34195 100644
--- a/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/dynamic_quant_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/dynamic_quant_parameter.h"
+#include "nnacl_c/dynamic_quant_parameter.h"
 #include "infer/dynamic_quant.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h"
 using mindspore::ops::kNameDynamicQuant;
diff --git a/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc b/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc
index cd37e263..42e93a36 100644
--- a/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/embedding_lookup_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/embedding_lookup_fp32.h"
+#include "nnacl_c/fp32/embedding_lookup_fp32.h"
 #include "infer/embedding_lookup.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h"
 using mindspore::ops::kMaxNorm;
diff --git a/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc b/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc
index 03b61766..3c414b50 100644
--- a/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/exp_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/fp32/exp_fp32.h"
+#include "nnacl_c/fp32/exp_fp32.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/exp_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc b/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc
index cf2cbc8b..e48de873 100644
--- a/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/flatten_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/flatten_parameter.h"
+#include "nnacl_c/flatten_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h"
 using mindspore::ops::kNameFlatten;
diff --git a/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc b/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc
index 279bf12f..f850b7dd 100644
--- a/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/full_connection_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "infer/cxx_api/full_connection.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h"
 using mindspore::ops::kNameFullConnection;
diff --git a/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc
index bb27d5c8..ba21e6cf 100644
--- a/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/fused_batchnorm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 #include "infer/fused_batch_norm.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h"
 using mindspore::ops::kNameFusedBatchNorm;
diff --git a/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc b/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc
index 9d565e52..de7c0822 100644
--- a/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/glu_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/glu_parameter.h"
+#include "nnacl_c/glu_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h"
 using mindspore::ops::kAxis;
 using mindspore::ops::kNameGLU;
diff --git a/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc
index f303eee7..ad9110b4 100644
--- a/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/group_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/group_norm_parameter.h"
+#include "nnacl_c/group_norm_parameter.h"
 #include "infer/cxx_api/groupnorm_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h"
 using mindspore::ops::kNameGroupNormFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc b/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc
index 014d1b63..0c1b0a6b 100644
--- a/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/gru_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/gru_fp32.h"
+#include "nnacl_c/fp32/gru_fp32.h"
 #include "infer/gru.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h"
 using mindspore::ops::kBidirectional;
diff --git a/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc
index c83da511..78498168 100644
--- a/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/instance_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/instance_norm_parameter.h"
+#include "nnacl_c/instance_norm_parameter.h"
 #include "infer/instance_norm.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h"
 using mindspore::ops::kEpsilon;
diff --git a/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc
index 8243dedc..219ddf16 100644
--- a/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/l2_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/l2_norm_parameter.h"
+#include "nnacl_c/l2_norm_parameter.h"
 #include "infer/l2_normalize.h"
 #include "infer/cxx_api/l2_normalize_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc
index a3b2b32a..cf106825 100644
--- a/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/layer_norm_grad_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32_grad/layernormgrad_parameter.h"
+#include "nnacl_c/fp32_grad/layernormgrad_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
 using mindspore::ops::kNameLayerNormGrad;
diff --git a/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc
index 8166bc66..0034b530 100644
--- a/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/layer_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/layer_norm_parameter.h"
+#include "nnacl_c/layer_norm_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/layer_norm_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc b/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc
index d004ca57..8c806e6d 100644
--- a/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/local_response_normalization_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/local_response_norm_fp32.h"
+#include "nnacl_c/fp32/local_response_norm_fp32.h"
 #include "infer/lrn.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
 using mindspore::ops::kNameLRN;
diff --git a/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc b/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc
index 10e5701a..c4e58609 100644
--- a/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/log_softmax_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc b/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc
index 6cd51b7f..a9eb9942 100644
--- a/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/lsh_projection_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/lsh_projection_parameter.h"
+#include "nnacl_c/lsh_projection_parameter.h"
 #include "infer/lsh_projection.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
 using mindspore::ops::kNameLshProjection;
diff --git a/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc b/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc
index 9ee910b8..d9faa31a 100644
--- a/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/lstm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 #include "infer/lstm.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h"
 using mindspore::ops::kNameLSTM;
diff --git a/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc b/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc
index 939c686f..ed17cbd1 100644
--- a/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/matmul_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/mat_mul_fusion.h"
 #include "mindspore/ops/op_def/op_name.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc b/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc
index eaf48b97..613e9b59 100644
--- a/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/mfcc_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/infer/mfcc_infer.h"
+#include "nnacl_c/infer/mfcc_infer.h"
 #include "infer/mfcc.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h"
 using mindspore::ops::kNameMfcc;
diff --git a/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc b/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc
index cf04225c..2a5fbdd7 100644
--- a/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/nllloss_populate.cc
@@ -16,7 +16,7 @@
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/nllloss_parameter.h"
+#include "nnacl_c/nllloss_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h"
 using mindspore::ops::kNameNLLLoss;
diff --git a/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc b/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc
index 90b0907b..ace5ba69 100644
--- a/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/non_max_suppression_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/non_max_suppression_parameter.h"
+#include "nnacl_c/non_max_suppression_parameter.h"
 #include "infer/non_max_suppression.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h"
 using mindspore::ops::kNameNonMaxSuppression;
diff --git a/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc b/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc
index 55761e8c..77a5f2ea 100644
--- a/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/one_hot_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/one_hot_fp32.h"
+#include "nnacl_c/fp32/one_hot_fp32.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h"
 using mindspore::ops::kNameOneHot;
diff --git a/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h b/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h
index 204f486c..50718f0f 100644
--- a/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h
+++ b/mindspore-lite/src/common/ops/operator_populate/operator_populate_register.h
@@ -22,7 +22,7 @@
 #include
 #include
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/common.h"
 #include "src/common/log_adapter.h"
 #include "src/common/version_manager.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc b/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc
index 35bf2746..225ffc5a 100644
--- a/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/p_relu_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/prelu_parameter.h"
+#include "nnacl_c/prelu_parameter.h"
 #include "infer/cxx_api/prelu_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h"
 using mindspore::ops::kNamePReLUFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc
index e2a90db1..d6940759 100644
--- a/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/pad_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/pad_parameter.h"
+#include "nnacl_c/pad_parameter.h"
 #include "infer/cxx_api/pad_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h"
 using mindspore::ops::kNamePadFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc b/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc
index 96c9855f..efc57b27 100644
--- a/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/partial_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/partial_fusion_parameter.h"
+#include "nnacl_c/partial_fusion_parameter.h"
 #include "infer/cxx_api/partial_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h"
 using mindspore::ops::kNamePartialFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc b/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc
index 0339695c..2f8358ad 100644
--- a/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/pooling_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/pooling_parameter.h"
+#include "nnacl_c/pooling_parameter.h"
 #include "infer/cxx_api/avg_pool_fusion.h"
 #include "infer/cxx_api/max_pool_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/power_populate.cc b/mindspore-lite/src/common/ops/operator_populate/power_populate.cc
index 08001d52..7960d5d9 100644
--- a/mindspore-lite/src/common/ops/operator_populate/power_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/power_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/pow_parameter.h"
+#include "nnacl_c/pow_parameter.h"
 #include "infer/cxx_api/pow_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h"
 using mindspore::ops::kNamePowFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc b/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc
index e58e8448..93ae68de 100644
--- a/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/prior_box_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/prior_box_parameter.h"
+#include "nnacl_c/prior_box_parameter.h"
 #include "infer/prior_box.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h"
 using mindspore::ops::kNamePriorBox;
diff --git a/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc b/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc
index acbb4c1c..40c3054d 100644
--- a/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/quant_dtype_cast_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
 #include "infer/quant_dtype_cast.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h"
 using mindspore::ops::kNameQuantDTypeCast;
diff --git a/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc b/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc
index a293805f..bbf19eb5 100644
--- a/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/random_normal_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/random_parameter.h"
+#include "nnacl_c/random_parameter.h"
 #include "infer/random_normal.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 using mindspore::ops::kNameRandomNormal;
diff --git a/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc b/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc
index 7f4a060d..e2155b90 100644
--- a/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/random_standard_normal_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/random_parameter.h"
+#include "nnacl_c/random_parameter.h"
 #include "infer/random_standard_normal.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 using mindspore::ops::kNameRandomStandardNormal;
diff --git a/mindspore-lite/src/common/ops/operator_populate/range_populate.cc b/mindspore-lite/src/common/ops/operator_populate/range_populate.cc
index b2a65a81..4446c9e8 100644
--- a/mindspore-lite/src/common/ops/operator_populate/range_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/range_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/range_parameter.h"
+#include "nnacl_c/range_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 using mindspore::ops::kNameRange;
diff --git a/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc b/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc
index 6742ef2c..73b738a4 100644
--- a/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/reduce_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/reduce_parameter.h"
+#include "nnacl_c/reduce_parameter.h"
 #include "infer/reduce.h"
 #include "infer/cxx_api/reduce_fusion.h"
 #include "mindspore/ops/op_def/op_name.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc b/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc
index e308a5ce..d3b1f7b9 100644
--- a/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/reduce_scatter.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/reduce_scatter_parameter.h"
+#include "nnacl_c/reduce_scatter_parameter.h"
 #include "infer/reduce_scatter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 using mindspore::ops::kNameReduceScatter;
diff --git a/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc b/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc
index db587a91..44259066 100644
--- a/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/resize_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/resize_parameter.h"
+#include "nnacl_c/resize_parameter.h"
 #include "infer/resize.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 using mindspore::ops::kNameResize;
diff --git a/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc b/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc
index d6a6a131..aad907d5 100644
--- a/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/reverse_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/reverse_fp32.h"
+#include "nnacl_c/fp32/reverse_fp32.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc b/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc
index 9a670f87..a4c2afa2 100644
--- a/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/reverse_sequence_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/reverse_sequence_parameter.h"
+#include "nnacl_c/reverse_sequence_parameter.h"
 #include "infer/reverse_sequence.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 using mindspore::ops::kNameReverseSequence;
diff --git a/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc b/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc
index f0856d7b..fe3c8dd6 100644
--- a/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/roi_pooling_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/roi_pooling_fp32.h"
+#include "nnacl_c/fp32/roi_pooling_fp32.h"
 #include "infer/roi_pooling.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
 using mindspore::ops::kNameROIPooling;
diff --git a/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc b/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc
index 6530ef14..c51faba0 100644
--- a/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/scale_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/scale_parameter.h"
+#include "nnacl_c/scale_parameter.h"
 #include "infer/cxx_api/scale_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameScaleFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc b/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc
index 0bef9f13..68c42696 100644
--- a/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/scatter_element_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/scatter_elements_parameter.h"
+#include "nnacl_c/scatter_elements_parameter.h"
 #include "infer/scatter_elements.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameScatterElements;
diff --git a/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc b/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc
index d1e01509..494d7405 100644
--- a/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/skip_gram_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/skip_gram_parameter.h"
+#include "nnacl_c/skip_gram_parameter.h"
 #include "infer/skip_gram.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSkipGram;
diff --git a/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc b/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc
index bcf3808c..71d77374 100644
--- a/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/slice_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/slice_parameter.h"
+#include "nnacl_c/slice_parameter.h"
 #include "infer/cxx_api/slice_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSliceFusion;
diff --git a/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc b/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc
index 1ddd8b47..aa8b16c2 100644
--- a/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/softmax_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kAxis;
diff --git a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc
index 8c33063d..302f527e 100644
--- a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_nd_populate.cc
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/space_to_batch_fp32.h"
+#include "nnacl_c/fp32/space_to_batch_fp32.h"
 #include "infer/space_to_batch_nd.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSpaceToBatchND;
diff --git a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc
index efd292d8..39b8326f 100644
--- a/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/space_to_batch_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/space_to_batch_fp32.h"
+#include "nnacl_c/fp32/space_to_batch_fp32.h"
 #include "infer/space_to_batch.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSpaceToBatch;
diff --git a/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc b/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc
index c71e27a1..778c9783 100644
--- a/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/space_to_depth_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/space_to_depth_parameter.h"
+#include "nnacl_c/space_to_depth_parameter.h"
 #include "infer/space_to_depth.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSpaceToDepth;
diff --git a/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc b/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc
index 7261d819..f52e6de5 100644
--- a/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/sparse_softmax_cross_entropy_with_logits_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
 #include "infer/sparse_softmax_cross_entropy_with_logits.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSparseSoftmaxCrossEntropyWithLogits;
diff --git a/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc b/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc
index 345ae177..bffd6c9d 100644
--- a/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/splice_populate.cc
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/op_base.h"
-#include "nnacl/splice_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/splice_parameter.h"
 #include "infer/splice.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSplice;
diff --git a/mindspore-lite/src/common/ops/operator_populate/split_populate.cc b/mindspore-lite/src/common/ops/operator_populate/split_populate.cc
index f20b6d35..3c2ba4f2 100644
--- a/mindspore-lite/src/common/ops/operator_populate/split_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/split_populate.cc
@@ -15,8 +15,8 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/operator_populate/utils.h"
-#include "nnacl/split_parameter.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/split_parameter.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kAxis;
diff --git a/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc b/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc
index 2dee4eea..01f028ac 100644
--- a/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/split_with_overlap_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/split_parameter.h"
+#include "nnacl_c/split_parameter.h"
 #include "infer/split_with_overlap.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h"
 using mindspore::ops::kNameSplitWithOverlap;
diff --git a/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc b/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc
index 3a773056..68fd05f6 100644
--- a/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/squeeze_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/squeeze_parameter.h" +#include "nnacl_c/squeeze_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kAxis; diff --git a/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc index 114b3a56..ba55af57 100644 --- a/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/stack_operator_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" #include "infer/stack.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameStack; diff --git a/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc b/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc index 9ad629ba..f97fc71d 100644 --- a/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/strided_slice_grad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" #include "infer/grad/strided_slice_grad.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameStridedSliceGrad; diff --git a/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc index 84958329..91973e83 100644 --- a/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/strided_slice_operator_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/operator_populate/operator_populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" using mindspore::ops::kNameStridedSlice; diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc index 6093df85..b8cfdba1 100644 --- a/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc +++ b/mindspore-lite/src/common/ops/operator_populate/tensor_array_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/tensor_array_parameter.h"
+#include "nnacl_c/tensor_array_parameter.h"
 #include "infer/tensor_array.h"
 #include "infer/tensor_array_read.h"
 #include "infer/tensor_array_write.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc
index 722b01ef..3b872825 100644
--- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_from_tensor_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 #include "infer/tensor_list_from_tensor.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
 using mindspore::ops::kNameTensorListFromTensor;
diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc
index 82878e87..49e41305 100644
--- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_get_item_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 #include "infer/tensor_list_get_item.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
 using mindspore::ops::kElement_dtype;
diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc
index fb42bcca..61bbca3c 100644
--- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_reserve_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 #include "infer/tensor_list_reserve.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
 using mindspore::ops::kNameTensorListReserve;
diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc
index 16fbb7dc..44b8a8a8 100644
--- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_set_item_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 #include "infer/tensor_list_set_item.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
 using mindspore::ops::kNameTensorListSetItem;
diff --git a/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc
index 9b1611d0..a5dfa96c 100644
--- a/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/tensor_list_stack_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 #include "infer/tensor_list_stack.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
 using mindspore::ops::kNameTensorListStack;
diff --git a/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc b/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc
index 07767c58..3e60e3cb 100644
--- a/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/tile_operator_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/base/tile_base.h"
+#include "nnacl_c/base/tile_base.h"
 #include "infer/cxx_api/tile_fusion.h"
 #include "mindspore/ops/op_def/op_name.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"
diff --git a/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc b/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc
index c3e38984..6d02fe2e 100644
--- a/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/topk_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/fp32/topk_fp32.h"
+#include "nnacl_c/fp32/topk_fp32.h"
 #include "infer/cxx_api/topk_fusion.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"

diff --git a/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc b/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc
index c08c0435..df55e61d 100644
--- a/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/uniform_real_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
 #include "src/common/ops/populate/default_populate.h"
-#include "nnacl/random_parameter.h"
+#include "nnacl_c/random_parameter.h"
 #include "infer/uniform_real.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h"
 using mindspore::ops::kNameUniformReal;
diff --git a/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc b/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc
index fcfb8a5e..b22ca15b 100644
--- a/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/unsqueeze_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/unsqueeze_parameter.h"
+#include "nnacl_c/unsqueeze_parameter.h"
 #include "infer/unsqueeze.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h"
 using mindspore::ops::kAxis;
diff --git a/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc b/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc
index ab05ff38..49c195e4 100644
--- a/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc
+++ b/mindspore-lite/src/common/ops/operator_populate/unstack_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/operator_populate/operator_populate_register.h"
-#include "nnacl/unstack_parameter.h"
+#include "nnacl_c/unstack_parameter.h"
 #include "infer/unstack.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h"
 using mindspore::ops::kAxis;
diff --git a/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc b/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc
index 59a8c297..121a32da 100644
--- a/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/activation_grad_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
 using mindspore::schema::PrimitiveType_ActivationGrad;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/activation_populate.cc b/mindspore-lite/src/common/ops/populate/activation_populate.cc
index 1cd1d966..893a57a0 100644
--- a/mindspore-lite/src/common/ops/populate/activation_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/activation_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 using mindspore::schema::PrimitiveType_Activation;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/adam_populate.cc b/mindspore-lite/src/common/ops/populate/adam_populate.cc
index 6dd94f7b..119130e5 100644
--- a/mindspore-lite/src/common/ops/populate/adam_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/adam_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 using mindspore::schema::PrimitiveType_Adam;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/add_populate.cc b/mindspore-lite/src/common/ops/populate/add_populate.cc
index 1de02b45..e529d2ae 100644
--- a/mindspore-lite/src/common/ops/populate/add_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/add_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" #include "src/common/ops/populate/arithmetic_populate.h" using mindspore::schema::PrimitiveType_AddFusion; diff --git a/mindspore-lite/src/common/ops/populate/adder_populate.cc b/mindspore-lite/src/common/ops/populate/adder_populate.cc index a09e0064..337c23fe 100644 --- a/mindspore-lite/src/common/ops/populate/adder_populate.cc +++ b/mindspore-lite/src/common/ops/populate/adder_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/log_adapter.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_AdderFusion; diff --git a/mindspore-lite/src/common/ops/populate/affine_populate.cc b/mindspore-lite/src/common/ops/populate/affine_populate.cc index 3e780d8c..7730a09c 100644 --- a/mindspore-lite/src/common/ops/populate/affine_populate.cc +++ b/mindspore-lite/src/common/ops/populate/affine_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/op_base.h" -#include "nnacl/affine_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/affine_parameter.h" using mindspore::schema::PrimitiveType_Affine; diff --git a/mindspore-lite/src/common/ops/populate/all_gather.cc b/mindspore-lite/src/common/ops/populate/all_gather.cc index ea8c1f1e..3c1e2b7d 100644 --- a/mindspore-lite/src/common/ops/populate/all_gather.cc +++ b/mindspore-lite/src/common/ops/populate/all_gather.cc @@ -16,7 +16,7 @@ #include "schema/ops_generated.h" #include "schema/model_generated.h" -#include "nnacl/all_gather_parameter.h" +#include "nnacl_c/all_gather_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_AllGather; diff --git a/mindspore-lite/src/common/ops/populate/argmax_populate.cc b/mindspore-lite/src/common/ops/populate/argmax_populate.cc index 49639aa3..06538206 100644 --- a/mindspore-lite/src/common/ops/populate/argmax_populate.cc +++ b/mindspore-lite/src/common/ops/populate/argmax_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arg_min_max_parameter.h" +#include "nnacl_c/arg_min_max_parameter.h" using mindspore::schema::PrimitiveType_ArgMaxFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/argmin_populate.cc b/mindspore-lite/src/common/ops/populate/argmin_populate.cc index 280d1edb..730daaf4 100644 --- a/mindspore-lite/src/common/ops/populate/argmin_populate.cc +++ b/mindspore-lite/src/common/ops/populate/argmin_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/arg_min_max_parameter.h"
+#include "nnacl_c/arg_min_max_parameter.h"
 using mindspore::schema::PrimitiveType_ArgMinFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/arithmetic_populate.h b/mindspore-lite/src/common/ops/populate/arithmetic_populate.h
index 7601f25c..e79fa2c4 100644
--- a/mindspore-lite/src/common/ops/populate/arithmetic_populate.h
+++ b/mindspore-lite/src/common/ops/populate/arithmetic_populate.h
@@ -16,7 +16,7 @@
 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_
 #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_

-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc b/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc
index 05e0253b..69e3725b 100644
--- a/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/arithmetic_self_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/log_adapter.h"
-#include "nnacl/arithmetic_self_parameter.h"
+#include "nnacl_c/arithmetic_self_parameter.h"
 #include "src/common/ops/populate/populate_register.h"
 using mindspore::schema::PrimitiveType_Abs;
 using mindspore::schema::PrimitiveType_Ceil;
diff --git a/mindspore-lite/src/common/ops/populate/attention_populate.cc b/mindspore-lite/src/common/ops/populate/attention_populate.cc
index 69c0bdd5..75c86cf0 100644
--- a/mindspore-lite/src/common/ops/populate/attention_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/attention_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/attention_parameter.h"
+#include "nnacl_c/attention_parameter.h"

 using mindspore::schema::PrimitiveType_Attention;

diff --git a/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc b/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc
index b2edc51e..1eb73d61 100644
--- a/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/audio_spectrogram_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/infer/audio_spectrogram_infer.h"
+#include "nnacl_c/infer/audio_spectrogram_infer.h"
 using mindspore::schema::PrimitiveType_AudioSpectrogram;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc b/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc
index 9923672e..7e2f5418 100644
--- a/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/batch_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 using mindspore::schema::PrimitiveType_BatchNorm;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc b/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc
index 4820b6c7..14105093 100644
--- a/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/batch_to_space_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/batch_to_space_parameter.h"
+#include "nnacl_c/batch_to_space_parameter.h"
 using mindspore::schema::PrimitiveType_BatchToSpace;
 using mindspore::schema::PrimitiveType_BatchToSpaceND;

diff --git a/mindspore-lite/src/common/ops/populate/bias_add_populate.cc b/mindspore-lite/src/common/ops/populate/bias_add_populate.cc
index aa2b6593..bdac3bef 100644
--- a/mindspore-lite/src/common/ops/populate/bias_add_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/bias_add_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"
 using mindspore::schema::PrimitiveType_BiasAdd;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc b/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc
index a8c51d1c..59017fab 100644
--- a/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/broadcast_to_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/base/broadcast_to.h"
+#include "nnacl_c/base/broadcast_to.h"
 using mindspore::schema::PrimitiveType_BroadcastTo;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/call_populate.cc b/mindspore-lite/src/common/ops/populate/call_populate.cc
index 105ff5ee..67da1fff 100644
--- a/mindspore-lite/src/common/ops/populate/call_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/call_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/call_parameter.h"
+#include "nnacl_c/call_parameter.h"
 using mindspore::schema::PrimitiveType_Call;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/clip_populate.cc b/mindspore-lite/src/common/ops/populate/clip_populate.cc
index df1da96a..40d8155e 100644
--- a/mindspore-lite/src/common/ops/populate/clip_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/clip_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/clip_parameter.h"
+#include "nnacl_c/clip_parameter.h"
 using mindspore::schema::PrimitiveType_Clip;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/concat_populate.cc b/mindspore-lite/src/common/ops/populate/concat_populate.cc
index 485f7114..e86126c0 100644
--- a/mindspore-lite/src/common/ops/populate/concat_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/concat_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"
 using mindspore::schema::PrimitiveType_Concat;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc b/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc
index 231c2dec..519c9479 100644
--- a/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/constant_of_shape_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "ir/dtype/type_id.h"
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/constant_of_shape_parameter.h"
+#include "nnacl_c/constant_of_shape_parameter.h"
 using mindspore::schema::PrimitiveType_ConstantOfShape;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc
index b834071a..6858ebf7 100644
--- a/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/control/tensor_array_populate.cc
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/op_base.h"
-#include "nnacl/tensor_array_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_array_parameter.h"

 using mindspore::schema::PrimitiveType_TensorArray;
 using mindspore::schema::PrimitiveType_TensorArrayRead;
diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc
index cd72d1f9..83239162 100644
--- a/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/control/tensorlistfromtensor_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 #include "src/common/ops/populate/populate_register.h"
 using mindspore::schema::PrimitiveType_TensorListFromTensor;

diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc
index 0374044d..64595636 100644
--- a/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/control/tensorlistgetitem_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 using mindspore::schema::PrimitiveType_TensorListGetItem;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc
index 306f2529..f7590986 100644
--- a/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/control/tensorlistreserve_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 using mindspore::schema::PrimitiveType_TensorListReserve;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc
index 7cfebab4..1e2107d7 100644
--- a/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/control/tensorlistsetlitem_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 using mindspore::schema::PrimitiveType_TensorListSetItem;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc b/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc
index b053c251..1e720291 100644
--- a/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/control/tensorliststack_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
 using mindspore::schema::PrimitiveType_TensorListStack;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/conv2d_populate.cc b/mindspore-lite/src/common/ops/populate/conv2d_populate.cc
index 79d8b792..3310cdb1 100644
--- a/mindspore-lite/src/common/ops/populate/conv2d_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/conv2d_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "src/common/ops/populate/populate_register.h"
 using mindspore::schema::PrimitiveType_Conv2DFusion;

diff --git a/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc b/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc
index 096f51a6..47072764 100644
--- a/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/crop_and_resize_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/resize_parameter.h"
+#include "nnacl_c/resize_parameter.h"
 using mindspore::schema::PrimitiveType_CropAndResize;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/crop_populate.cc b/mindspore-lite/src/common/ops/populate/crop_populate.cc
index 7db5c4b5..357a36b0 100644
--- a/mindspore-lite/src/common/ops/populate/crop_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/crop_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/crop_parameter.h"
+#include "nnacl_c/crop_parameter.h"
 using mindspore::schema::PrimitiveType_Crop;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/cumsum_populate.cc b/mindspore-lite/src/common/ops/populate/cumsum_populate.cc
index 76fc45a3..af59a8a3 100644
--- a/mindspore-lite/src/common/ops/populate/cumsum_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/cumsum_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/cumsum_parameter.h"
+#include "nnacl_c/cumsum_parameter.h"
 using mindspore::schema::PrimitiveType_CumSum;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/custom_populate.cc b/mindspore-lite/src/common/ops/populate/custom_populate.cc
index 0cde665b..84ea9ace 100644
--- a/mindspore-lite/src/common/ops/populate/custom_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/custom_populate.cc
@@ -16,14 +16,14 @@
 #include "src/common/ops/populate/populate_register.h"
 #include "src/common/log_adapter.h"
 #include "src/tensor.h"
-#include "nnacl/custom_parameter.h"
-#include "nnacl/split_parameter.h"
-#include "nnacl/custom_gru_parameter.h"
-#include "nnacl/custom_masked_fill_parameter.h"
-#include "nnacl/custom_is_inf_parameter.h"
-#include "nnacl/scatter_nd_parameter.h"
-#include "nnacl/conv3d_parameter.h"
-#include "nnacl/grid_sampler_parameter.h"
+#include "nnacl_c/custom_parameter.h"
+#include "nnacl_c/split_parameter.h"
+#include "nnacl_c/custom_gru_parameter.h"
+#include "nnacl_c/custom_masked_fill_parameter.h"
+#include "nnacl_c/custom_is_inf_parameter.h"
+#include "nnacl_c/scatter_nd_parameter.h"
+#include "nnacl_c/conv3d_parameter.h"
+#include "nnacl_c/grid_sampler_parameter.h"

 using mindspore::schema::PrimitiveType_Custom;

diff --git a/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc b/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc
index 4b6b6a76..25c8fcbf 100644
--- a/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/deconv2d_populate.cc
@@ -15,7 +15,7 @@
  */
 #include "src/common/log_adapter.h"
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/default_populate.h b/mindspore-lite/src/common/ops/populate/default_populate.h
index 6ee48376..d13aaa85 100644
--- a/mindspore-lite/src/common/ops/populate/default_populate.h
+++ b/mindspore-lite/src/common/ops/populate/default_populate.h
@@ -16,7 +16,7 @@
 #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_DEFAULT_POPULATE_H_
 #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_DEFAULT_POPULATE_H_

-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc b/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc
index 0c7f6a6b..5b833598 100644
--- a/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/depth_to_space_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/depth_to_space_parameter.h"
+#include "nnacl_c/depth_to_space_parameter.h"
 using mindspore::schema::PrimitiveType_DepthToSpace;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc b/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc
index 6cfdb35c..bbb625d3 100644
--- a/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/detection_post_process_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/detection_post_process_parameter.h"
+#include "nnacl_c/detection_post_process_parameter.h"
 using mindspore::schema::PrimitiveType_DetectionPostProcess;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc b/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc
index 8e393320..86f732f9 100644
--- a/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/dynamic_quant_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/dynamic_quant_parameter.h"
+#include "nnacl_c/dynamic_quant_parameter.h"
 using mindspore::schema::PrimitiveType_DynamicQuant;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc b/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc
index 87b56c02..5173b69b 100644
--- a/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/embedding_lookup_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/embedding_lookup_fp32.h"
+#include "nnacl_c/fp32/embedding_lookup_fp32.h"
 using mindspore::schema::PrimitiveType_EmbeddingLookupFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/exp_populate.cc b/mindspore-lite/src/common/ops/populate/exp_populate.cc
index 86c02456..c14d969c 100644
--- a/mindspore-lite/src/common/ops/populate/exp_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/exp_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/exp_fp32.h"
+#include "nnacl_c/fp32/exp_fp32.h"
 using mindspore::schema::PrimitiveType_ExpFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/flatten_populate.cc b/mindspore-lite/src/common/ops/populate/flatten_populate.cc
index f4f30e77..045bd536 100644
--- a/mindspore-lite/src/common/ops/populate/flatten_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/flatten_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/flatten_parameter.h"
+#include "nnacl_c/flatten_parameter.h"
 using mindspore::schema::PrimitiveType_Flatten;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/full_connection_populate.cc b/mindspore-lite/src/common/ops/populate/full_connection_populate.cc
index 30106e64..00b57404 100644
--- a/mindspore-lite/src/common/ops/populate/full_connection_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/full_connection_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 using mindspore::schema::PrimitiveType_FullConnection;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc b/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc
index a23fb707..e5ad31d5 100644
--- a/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/fused_batchnorm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 using mindspore::schema::PrimitiveType_FusedBatchNorm;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/gather_d_populate.cc b/mindspore-lite/src/common/ops/populate/gather_d_populate.cc
index b05fcfee..a1bee47b 100644
--- a/mindspore-lite/src/common/ops/populate/gather_d_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/gather_d_populate.cc
@@ -15,7 +15,7 @@
  */

 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/gather_parameter.h"
 using mindspore::schema::PrimitiveType_GatherD;

diff --git a/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc b/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc
index 980a1adf..cf79b247 100644
--- a/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/gather_nd_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/gather_nd_parameter.h"
+#include "nnacl_c/gather_nd_parameter.h"
 using mindspore::schema::PrimitiveType_GatherNd;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/gather_populate.cc b/mindspore-lite/src/common/ops/populate/gather_populate.cc
index 7e19ccd9..1ac8e829 100644
--- a/mindspore-lite/src/common/ops/populate/gather_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/gather_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/gather_parameter.h"
 using mindspore::schema::PrimitiveType_Gather;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/glu_populate.cc b/mindspore-lite/src/common/ops/populate/glu_populate.cc
index 96d23266..4d2dbd89 100644
--- a/mindspore-lite/src/common/ops/populate/glu_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/glu_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/glu_parameter.h"
+#include "nnacl_c/glu_parameter.h"
 using mindspore::schema::PrimitiveType_GLU;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/group_norm_populate.cc b/mindspore-lite/src/common/ops/populate/group_norm_populate.cc
index c832e705..59d3f6e7 100644
--- a/mindspore-lite/src/common/ops/populate/group_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/group_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/group_norm_parameter.h"
+#include "nnacl_c/group_norm_parameter.h"
 using mindspore::schema::PrimitiveType_GroupNormFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/gru_populate.cc b/mindspore-lite/src/common/ops/populate/gru_populate.cc
index ed157b6b..df247441 100644
--- a/mindspore-lite/src/common/ops/populate/gru_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/gru_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/gru_fp32.h"
+#include "nnacl_c/fp32/gru_fp32.h"
 using mindspore::schema::PrimitiveType_GRU;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc b/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc
index 71acd6e3..dcc223c2 100644
--- a/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/instance_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/instance_norm_parameter.h"
+#include "nnacl_c/instance_norm_parameter.h"
 using mindspore::schema::PrimitiveType_InstanceNorm;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc b/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc
index c1dc48da..46850753 100644
--- a/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/l2_norm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/l2_norm_parameter.h"
+#include "nnacl_c/l2_norm_parameter.h"
 using mindspore::schema::PrimitiveType_L2NormalizeFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc b/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc
index fc41e4d6..498a23e2 100644
--- a/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/layer_norm_grad_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "nnacl/fp32_grad/layernormgrad_parameter.h"
+#include "nnacl_c/fp32_grad/layernormgrad_parameter.h"
 #include "src/common/ops/populate/populate_register.h"
 using mindspore::schema::PrimitiveType_LayerNormGrad;

diff --git a/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc b/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc
index 9da07bfc..5ee80565 100644
--- a/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/layer_norm_populate.cc
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "nnacl/layer_norm_parameter.h"
+#include "nnacl_c/layer_norm_parameter.h"
 #include 
 #include "src/common/ops/populate/populate_register.h"
 using mindspore::schema::PrimitiveType_LayerNormFusion;
diff --git a/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc b/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc
index 2b372e60..19439587 100644
--- a/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/local_response_normalization_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/local_response_norm_fp32.h"
+#include "nnacl_c/fp32/local_response_norm_fp32.h"
 using mindspore::schema::PrimitiveType_LRN;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc b/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc
index 40ae66b3..fe720d40 100644
--- a/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/log_softmax_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 using mindspore::schema::PrimitiveType_LogSoftmax;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/lstm_populate.cc b/mindspore-lite/src/common/ops/populate/lstm_populate.cc
index b3a85b64..65c3d9ec 100644
--- a/mindspore-lite/src/common/ops/populate/lstm_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/lstm_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 using mindspore::schema::PrimitiveType_LSTM;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/matmul_populate.cc b/mindspore-lite/src/common/ops/populate/matmul_populate.cc
index 8eb182b8..803bedda 100644
--- a/mindspore-lite/src/common/ops/populate/matmul_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/matmul_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 using mindspore::schema::PrimitiveType_MatMulFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/mfcc_populate.cc b/mindspore-lite/src/common/ops/populate/mfcc_populate.cc
index 3b7fc3d8..854872c6 100644
--- a/mindspore-lite/src/common/ops/populate/mfcc_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/mfcc_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/infer/mfcc_infer.h"
+#include "nnacl_c/infer/mfcc_infer.h"
 using mindspore::schema::PrimitiveType_Mfcc;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/mul_populate.cc b/mindspore-lite/src/common/ops/populate/mul_populate.cc
index 3b3c5df0..3524d842 100644
--- a/mindspore-lite/src/common/ops/populate/mul_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/mul_populate.cc
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" #include "src/common/ops/populate/populate_register.h" #include "src/common/ops/populate/arithmetic_populate.h" using mindspore::schema::PrimitiveType_MulFusion; diff --git a/mindspore-lite/src/common/ops/populate/nllloss_populate.cc b/mindspore-lite/src/common/ops/populate/nllloss_populate.cc index 9a3c9f44..86814c60 100644 --- a/mindspore-lite/src/common/ops/populate/nllloss_populate.cc +++ b/mindspore-lite/src/common/ops/populate/nllloss_populate.cc @@ -16,7 +16,7 @@ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/nllloss_parameter.h" +#include "nnacl_c/nllloss_parameter.h" using mindspore::schema::PrimitiveType_NLLLoss; using mindspore::schema::PrimitiveType_NLLLossGrad; diff --git a/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc b/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc index 485ff9c2..c9db6507 100644 --- a/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc +++ b/mindspore-lite/src/common/ops/populate/non_max_suppression_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/non_max_suppression_parameter.h" +#include "nnacl_c/non_max_suppression_parameter.h" using mindspore::schema::PrimitiveType_NonMaxSuppression; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/one_hot_populate.cc b/mindspore-lite/src/common/ops/populate/one_hot_populate.cc index 18caaa3d..efd737dc 100644 --- a/mindspore-lite/src/common/ops/populate/one_hot_populate.cc +++ b/mindspore-lite/src/common/ops/populate/one_hot_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl_c/fp32/one_hot_fp32.h" using mindspore::schema::PrimitiveType_OneHot; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/p_relu_populate.cc b/mindspore-lite/src/common/ops/populate/p_relu_populate.cc index cda27de9..48a544d7 100644 --- a/mindspore-lite/src/common/ops/populate/p_relu_populate.cc +++ b/mindspore-lite/src/common/ops/populate/p_relu_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/prelu_parameter.h" +#include "nnacl_c/prelu_parameter.h" using mindspore::schema::PrimitiveType_PReLUFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/pad_populate.cc b/mindspore-lite/src/common/ops/populate/pad_populate.cc index ac417186..2ea13a65 100644 --- a/mindspore-lite/src/common/ops/populate/pad_populate.cc +++ b/mindspore-lite/src/common/ops/populate/pad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" using mindspore::schema::PrimitiveType_PadFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/partial_populate.cc b/mindspore-lite/src/common/ops/populate/partial_populate.cc index b5516686..cab3b01e 100644 --- a/mindspore-lite/src/common/ops/populate/partial_populate.cc +++ b/mindspore-lite/src/common/ops/populate/partial_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/partial_fusion_parameter.h"
+#include "nnacl_c/partial_fusion_parameter.h"
 using mindspore::schema::PrimitiveType_PartialFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/pooling_populate.cc b/mindspore-lite/src/common/ops/populate/pooling_populate.cc
index dd9fd519..7f401afd 100644
--- a/mindspore-lite/src/common/ops/populate/pooling_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/pooling_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/pooling_parameter.h"
+#include "nnacl_c/pooling_parameter.h"
 using mindspore::schema::PrimitiveType_AvgPoolFusion;
 using mindspore::schema::PrimitiveType_MaxPoolFusion;

diff --git a/mindspore-lite/src/common/ops/populate/populate_register.h b/mindspore-lite/src/common/ops/populate/populate_register.h
index 226c58d4..428849b5 100644
--- a/mindspore-lite/src/common/ops/populate/populate_register.h
+++ b/mindspore-lite/src/common/ops/populate/populate_register.h
@@ -20,7 +20,7 @@
 #include 
 #include 
 #include "schema/model_generated.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/common.h"
 #include "src/common/log_adapter.h"
 #include "src/common/prim_util.h"
diff --git a/mindspore-lite/src/common/ops/populate/power_populate.cc b/mindspore-lite/src/common/ops/populate/power_populate.cc
index 2559626b..33450e87 100644
--- a/mindspore-lite/src/common/ops/populate/power_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/power_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/pow_parameter.h"
+#include "nnacl_c/pow_parameter.h"
 using mindspore::schema::PrimitiveType_PowFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/prior_box_populate.cc b/mindspore-lite/src/common/ops/populate/prior_box_populate.cc
index 60e66233..19766859 100644
--- a/mindspore-lite/src/common/ops/populate/prior_box_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/prior_box_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/prior_box_parameter.h"
+#include "nnacl_c/prior_box_parameter.h"
 using mindspore::schema::PrimitiveType_PriorBox;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc b/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc
index 028de9f3..35035e8d 100644
--- a/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/quant_dtype_cast_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
 using mindspore::schema::PrimitiveType_QuantDTypeCast;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/random_normal_populate.cc b/mindspore-lite/src/common/ops/populate/random_normal_populate.cc
index 79566406..5c92ead7 100644
--- a/mindspore-lite/src/common/ops/populate/random_normal_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/random_normal_populate.cc
@@ -15,7 +15,7 @@
  */

 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/random_parameter.h"
+#include "nnacl_c/random_parameter.h"
 using mindspore::schema::PrimitiveType_RandomNormal;

diff --git a/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc b/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc
index f432ae45..e4da6342 100644
--- a/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/random_standard_normal_populate.cc
@@ -15,7 +15,7 @@
  */

 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/random_parameter.h"
+#include "nnacl_c/random_parameter.h"
 using mindspore::schema::PrimitiveType_RandomStandardNormal;

diff --git a/mindspore-lite/src/common/ops/populate/range_populate.cc b/mindspore-lite/src/common/ops/populate/range_populate.cc
index 11325677..6bb081a0 100644
--- a/mindspore-lite/src/common/ops/populate/range_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/range_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/range_parameter.h"
+#include "nnacl_c/range_parameter.h"
 using mindspore::schema::PrimitiveType_Range;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/reduce_populate.cc b/mindspore-lite/src/common/ops/populate/reduce_populate.cc
index da4d3917..661b091e 100644
--- a/mindspore-lite/src/common/ops/populate/reduce_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/reduce_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/reduce_parameter.h"
+#include "nnacl_c/reduce_parameter.h"
 using mindspore::schema::PrimitiveType_ReduceFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/reduce_scatter.cc b/mindspore-lite/src/common/ops/populate/reduce_scatter.cc
index 1e02e6e8..0025cd57 100644
--- a/mindspore-lite/src/common/ops/populate/reduce_scatter.cc
+++ b/mindspore-lite/src/common/ops/populate/reduce_scatter.cc
@@ -16,7 +16,7 @@

 #include "schema/ops_generated.h"
 #include "schema/model_generated.h"
-#include "nnacl/reduce_scatter_parameter.h"
+#include "nnacl_c/reduce_scatter_parameter.h"
 #include "src/common/ops/populate/populate_register.h"
 using mindspore::schema::PrimitiveType_ReduceScatter;

diff --git a/mindspore-lite/src/common/ops/populate/reshape_populate.cc b/mindspore-lite/src/common/ops/populate/reshape_populate.cc
index d34bcb8a..82e6fab5 100644
--- a/mindspore-lite/src/common/ops/populate/reshape_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/reshape_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/reshape_parameter.h"
+#include "nnacl_c/reshape_parameter.h"
 using mindspore::schema::PrimitiveType_Reshape;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/resize_populate.cc b/mindspore-lite/src/common/ops/populate/resize_populate.cc
index 0d8ae5da..a46cdd60 100644
--- a/mindspore-lite/src/common/ops/populate/resize_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/resize_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/resize_parameter.h"
+#include "nnacl_c/resize_parameter.h"
 using mindspore::schema::PrimitiveType_Resize;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/reverse_populate.cc b/mindspore-lite/src/common/ops/populate/reverse_populate.cc
index 59ef7d23..3d16522a 100644
--- a/mindspore-lite/src/common/ops/populate/reverse_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/reverse_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/reverse_fp32.h"
+#include "nnacl_c/fp32/reverse_fp32.h"
 using mindspore::schema::PrimitiveType_ReverseV2;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc b/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc
index 761457d6..5896c32d 100644
--- a/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/reverse_sequence_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/reverse_sequence_parameter.h"
+#include "nnacl_c/reverse_sequence_parameter.h"
 using mindspore::schema::PrimitiveType_ReverseSequence;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc b/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc
index cf1f9d6f..c47a449b 100644
--- a/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/roi_pooling_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/roi_pooling_fp32.h"
+#include "nnacl_c/fp32/roi_pooling_fp32.h"
 using mindspore::schema::PrimitiveType_ROIPooling;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/scale_populate.cc b/mindspore-lite/src/common/ops/populate/scale_populate.cc
index 530543c2..780696d9 100644
--- a/mindspore-lite/src/common/ops/populate/scale_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/scale_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/scale_parameter.h"
+#include "nnacl_c/scale_parameter.h"
 using mindspore::schema::PrimitiveType_ScaleFusion;

 namespace mindspore {
diff --git a/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc b/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc
index 135fd261..fe7b8343 100644
--- a/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc
+++ b/mindspore-lite/src/common/ops/populate/scatter_element_populate.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/scatter_elements_parameter.h" +#include "nnacl_c/scatter_elements_parameter.h" using mindspore::schema::PrimitiveType_ScatterElements; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc b/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc index 7aa054ea..897bd854 100644 --- a/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc +++ b/mindspore-lite/src/common/ops/populate/scatter_nd_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/scatter_nd_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" using mindspore::schema::PrimitiveType_ScatterNd; diff --git a/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc b/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc index c8002c22..e75246af 100644 --- a/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc +++ b/mindspore-lite/src/common/ops/populate/scatter_nd_update_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/scatter_nd_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" using mindspore::schema::PrimitiveType_ScatterNdUpdate; using mindspore::schema::PrimitiveType_TensorScatterAdd; diff --git a/mindspore-lite/src/common/ops/populate/slice_populate.cc b/mindspore-lite/src/common/ops/populate/slice_populate.cc index c41899eb..270029f7 100644 --- a/mindspore-lite/src/common/ops/populate/slice_populate.cc +++ b/mindspore-lite/src/common/ops/populate/slice_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/slice_parameter.h" +#include "nnacl_c/slice_parameter.h" using mindspore::schema::PrimitiveType_SliceFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/softmax_populate.cc b/mindspore-lite/src/common/ops/populate/softmax_populate.cc index 66ccdc3b..8821d3a5 100644 --- a/mindspore-lite/src/common/ops/populate/softmax_populate.cc +++ b/mindspore-lite/src/common/ops/populate/softmax_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" using mindspore::schema::PrimitiveType_Softmax; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc b/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc index bc7b2799..201f7f0b 100644 --- a/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc +++ b/mindspore-lite/src/common/ops/populate/space_to_batch_nd_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" using mindspore::schema::PrimitiveType_SpaceToBatchND; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc b/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc index 95ef85ea..43ce9f35 100644 --- a/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc +++ b/mindspore-lite/src/common/ops/populate/space_to_batch_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" using mindspore::schema::PrimitiveType_SpaceToBatch; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc b/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc index 0b37c433..6454ac01 100644 --- a/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc +++ b/mindspore-lite/src/common/ops/populate/space_to_depth_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/space_to_depth_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" using mindspore::schema::PrimitiveType_SpaceToDepth; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc b/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc index f970e47e..74ccfe43 100644 --- a/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc +++ b/mindspore-lite/src/common/ops/populate/sparse_softmax_cross_entropy_with_logits_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32_grad/softmax_grad.h" +#include "nnacl_c/fp32_grad/softmax_grad.h" using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc b/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc index d4ac7f8e..f943e7b4 100644 --- a/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc +++ b/mindspore-lite/src/common/ops/populate/sparse_to_dense_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/sparse_to_dense_parameter.h" +#include "nnacl_c/sparse_to_dense_parameter.h" using mindspore::schema::PrimitiveType_SparseToDense; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/splice_populate.cc b/mindspore-lite/src/common/ops/populate/splice_populate.cc index 32767fde..f2b73334 100644 --- a/mindspore-lite/src/common/ops/populate/splice_populate.cc +++ b/mindspore-lite/src/common/ops/populate/splice_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/op_base.h" -#include "nnacl/splice_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/splice_parameter.h" using mindspore::schema::PrimitiveType_Splice; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/split_populate.cc b/mindspore-lite/src/common/ops/populate/split_populate.cc index 83e3e67f..8bb30d3e 100644 --- a/mindspore-lite/src/common/ops/populate/split_populate.cc +++ b/mindspore-lite/src/common/ops/populate/split_populate.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/op_base.h" using mindspore::schema::PrimitiveType_Split; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc b/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc index c4402762..485b3bb3 100644 --- a/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc +++ b/mindspore-lite/src/common/ops/populate/split_with_overlap_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" using mindspore::schema::PrimitiveType_SplitWithOverlap; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/squeeze_populate.cc b/mindspore-lite/src/common/ops/populate/squeeze_populate.cc index 8767f130..d4aad570 100644 --- a/mindspore-lite/src/common/ops/populate/squeeze_populate.cc +++ b/mindspore-lite/src/common/ops/populate/squeeze_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/squeeze_parameter.h" +#include "nnacl_c/squeeze_parameter.h" using mindspore::schema::PrimitiveType_Squeeze; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/stack_populate.cc b/mindspore-lite/src/common/ops/populate/stack_populate.cc index 57bb5652..0094a8e2 100644 --- a/mindspore-lite/src/common/ops/populate/stack_populate.cc +++ b/mindspore-lite/src/common/ops/populate/stack_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" using mindspore::schema::PrimitiveType_Stack; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc b/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc index 88e82ea1..09feb95c 100644 --- a/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc +++ b/mindspore-lite/src/common/ops/populate/strided_slice_grad_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" using mindspore::schema::PrimitiveType_StridedSliceGrad; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/strided_slice_populate.h b/mindspore-lite/src/common/ops/populate/strided_slice_populate.h index 552cc40b..8a3cc505 100644 --- a/mindspore-lite/src/common/ops/populate/strided_slice_populate.h +++ b/mindspore-lite/src/common/ops/populate/strided_slice_populate.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_STRIDED_SLICE_POPULATE_H_ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc b/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc index 536b2cc2..2992b1af 100644 --- a/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc +++ b/mindspore-lite/src/common/ops/populate/string/custom_predict_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/predict_parameter.h" +#include "nnacl_c/predict_parameter.h" using mindspore::schema::PrimitiveType_CustomPredict; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc b/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc index 05a1517e..750027b6 100644 --- a/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc +++ b/mindspore-lite/src/common/ops/populate/string/lsh_projection_populate.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "nnacl/lsh_projection_parameter.h" +#include "nnacl_c/lsh_projection_parameter.h" #include "src/common/ops/populate/populate_register.h" using mindspore::schema::PrimitiveType_LshProjection; diff --git a/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc b/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc index 3baf9537..6b02efaa 100644 --- a/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc +++ b/mindspore-lite/src/common/ops/populate/string/skip_gram_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" using mindspore::schema::PrimitiveType_SkipGram; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/sub_populate.cc b/mindspore-lite/src/common/ops/populate/sub_populate.cc index e9023698..be1f3a99 100644 --- a/mindspore-lite/src/common/ops/populate/sub_populate.cc +++ b/mindspore-lite/src/common/ops/populate/sub_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" #include "src/common/ops/populate/arithmetic_populate.h" using mindspore::schema::PrimitiveType_SubFusion; diff --git a/mindspore-lite/src/common/ops/populate/tile_populate.cc b/mindspore-lite/src/common/ops/populate/tile_populate.cc index b1b555d8..faed1330 100644 --- a/mindspore-lite/src/common/ops/populate/tile_populate.cc +++ b/mindspore-lite/src/common/ops/populate/tile_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/base/tile_base.h" +#include "nnacl_c/base/tile_base.h" using mindspore::schema::PrimitiveType_TileFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/topk_populate.cc b/mindspore-lite/src/common/ops/populate/topk_populate.cc index a92f5277..66446101 100644 --- a/mindspore-lite/src/common/ops/populate/topk_populate.cc +++ b/mindspore-lite/src/common/ops/populate/topk_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/topk_fp32.h" +#include "nnacl_c/fp32/topk_fp32.h" using mindspore::schema::PrimitiveType_TopKFusion; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/transpose_populate.cc b/mindspore-lite/src/common/ops/populate/transpose_populate.cc index a13950b7..11e1c3e3 100644 --- a/mindspore-lite/src/common/ops/populate/transpose_populate.cc +++ b/mindspore-lite/src/common/ops/populate/transpose_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" using mindspore::schema::PrimitiveType_Transpose; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc b/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc index dcb6b512..2bc02a85 100644 --- a/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc +++ b/mindspore-lite/src/common/ops/populate/triu_tril_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/triu_tril_parameter.h" +#include "nnacl_c/triu_tril_parameter.h" using mindspore::schema::PrimitiveType_Tril; using mindspore::schema::PrimitiveType_Triu; diff --git a/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc b/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc index d901852a..a00aa15c 100644 --- a/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc +++ b/mindspore-lite/src/common/ops/populate/uniform_real_populate.cc @@ -15,7 +15,7 @@ */ #include "src/common/ops/populate/populate_register.h" #include "src/common/ops/populate/default_populate.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" using mindspore::schema::PrimitiveType_UniformReal; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/unique_populate.cc b/mindspore-lite/src/common/ops/populate/unique_populate.cc index 456f1c79..ebe42917 100644 --- a/mindspore-lite/src/common/ops/populate/unique_populate.cc +++ b/mindspore-lite/src/common/ops/populate/unique_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/unique_fp32.h" +#include "nnacl_c/fp32/unique_fp32.h" using mindspore::schema::PrimitiveType_Unique; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc b/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc index 5feafc4c..d859556c 100644 --- a/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc +++ b/mindspore-lite/src/common/ops/populate/unsqueeze_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/unsqueeze_parameter.h" +#include "nnacl_c/unsqueeze_parameter.h" using mindspore::schema::PrimitiveType_Unsqueeze; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/unstack_populate.cc b/mindspore-lite/src/common/ops/populate/unstack_populate.cc index 6602b08c..a77ded6c 100644 --- a/mindspore-lite/src/common/ops/populate/unstack_populate.cc +++ b/mindspore-lite/src/common/ops/populate/unstack_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/unstack_parameter.h" +#include "nnacl_c/unstack_parameter.h" using mindspore::schema::PrimitiveType_Unstack; namespace mindspore { diff --git a/mindspore-lite/src/common/ops/populate/where_populate.cc b/mindspore-lite/src/common/ops/populate/where_populate.cc index 48952038..f3a32b25 100644 --- a/mindspore-lite/src/common/ops/populate/where_populate.cc +++ b/mindspore-lite/src/common/ops/populate/where_populate.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/common/ops/populate/populate_register.h" -#include "nnacl/where_parameter.h" +#include "nnacl_c/where_parameter.h" using mindspore::schema::PrimitiveType_Where; namespace mindspore { diff --git a/mindspore-lite/src/common/prim_util.cc b/mindspore-lite/src/common/prim_util.cc index 7da14e5b..d640815d 100644 --- a/mindspore-lite/src/common/prim_util.cc +++ b/mindspore-lite/src/common/prim_util.cc @@ -17,7 +17,7 @@ #include "src/common/prim_util.h" #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "schema/model_generated.h" #include "src/common/log_adapter.h" diff --git a/mindspore-lite/src/common/tensor_util.cc b/mindspore-lite/src/common/tensor_util.cc index aecc0236..f6a33f1f 100644 --- a/mindspore-lite/src/common/tensor_util.cc +++ b/mindspore-lite/src/common/tensor_util.cc @@ -23,7 +23,7 @@ #include "src/common/log_adapter.h" #include "src/litert/pack_weight_manager.h" #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore { namespace lite { void FreeInTensorC(std::vector *tensors_in, const std::shared_ptr &allocator) { diff --git a/mindspore-lite/src/common/tensor_util.h b/mindspore-lite/src/common/tensor_util.h index bd75f74b..080bb448 100644 --- a/mindspore-lite/src/common/tensor_util.h +++ b/mindspore-lite/src/common/tensor_util.h @@ -20,11 +20,11 @@ #include #include #include "src/tensor.h" -#include "nnacl/tensor_c.h" -#include "nnacl/tensor_c_utils.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/tensor_c_utils.h" #include "src/tensorlist.h" -#include "nnacl/infer/common_infer.h" -#include "nnacl/tensorlist_c_utils.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/tensorlist_c_utils.h" #include "src/litert/cxx_api/tensor/tensor_impl.h" #include 
"include/api/visible.h" diff --git a/mindspore-lite/src/control_flow/control_flow_scheduler.cc b/mindspore-lite/src/control_flow/control_flow_scheduler.cc index a2a573de..bfa80d91 100644 --- a/mindspore-lite/src/control_flow/control_flow_scheduler.cc +++ b/mindspore-lite/src/control_flow/control_flow_scheduler.cc @@ -20,7 +20,7 @@ #include #include "src/litert/kernel_exec_util.h" #include "src/litert/kernel/cpu/base/partial_fusion.h" -#include "nnacl/call_parameter.h" +#include "nnacl_c/call_parameter.h" #include "src/control_flow/kernel/exit_subgraph_kernel.h" #include "src/control_flow/kernel/identity_kernel.h" #include "src/tensorlist.h" diff --git a/mindspore-lite/src/control_flow/control_flow_scheduler.h b/mindspore-lite/src/control_flow/control_flow_scheduler.h index 2344f380..910a6aa1 100644 --- a/mindspore-lite/src/control_flow/control_flow_scheduler.h +++ b/mindspore-lite/src/control_flow/control_flow_scheduler.h @@ -26,7 +26,7 @@ #include #include "src/common/utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/inner_context.h" #include "src/tensor.h" #include "src/executor/sub_graph_kernel.h" diff --git a/mindspore-lite/src/executor/kernel_exec.h b/mindspore-lite/src/executor/kernel_exec.h index 89a32b5f..37eacf3e 100644 --- a/mindspore-lite/src/executor/kernel_exec.h +++ b/mindspore-lite/src/executor/kernel_exec.h @@ -26,7 +26,7 @@ #ifdef ENABLE_ARM #include #endif -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/inner_context.h" #include "src/tensor.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/executor/sub_graph_kernel.h b/mindspore-lite/src/executor/sub_graph_kernel.h index 13f1ef19..8f84a021 100644 --- a/mindspore-lite/src/executor/sub_graph_kernel.h +++ b/mindspore-lite/src/executor/sub_graph_kernel.h @@ -29,7 +29,7 @@ #include "src/common/version_manager.h" #include "src/litert/cpu_info.h" #if defined(ENABLE_ARM) && defined(ENABLE_FP16) -#include "nnacl/constant_of_shape_parameter.h" +#include "nnacl_c/constant_of_shape_parameter.h" #endif namespace mindspore::kernel { diff --git a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc index e8ad77ec..d9233c4a 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc @@ -24,7 +24,7 @@ #include "ops/primitive_c.h" #include "tools/optimizer/common/gllo_utils.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "extendrt/cxx_api/model/model_impl.h" #include "extendrt/cxx_api/dlutils.h" diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc index 72712ed0..4fcdace8 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc @@ -17,7 +17,7 @@ #include "src/extendrt/cxx_api/model_pool/runner_config.h" #include "src/common/log_adapter.h" #include "src/litert/cpu_info.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #ifdef CAPTURE_SIGNALS #include "src/extendrt/signal_handler.h" #endif diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc index 5cb85137..24c8833d 100644 --- 
a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc @@ -18,7 +18,7 @@ #include #include #include -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/extendrt/cxx_api/model_pool/resource_manager.h" #include "src/common/log_adapter.h" #include "include/lite_types.h" diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc index d36025ac..abed1765 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_worker.cc @@ -18,7 +18,7 @@ #include "src/common/log_adapter.h" #include "src/extendrt/numa_adapter.h" #include "src/common/common.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { void ModelWorker::PrintWorkerInfo() { MS_LOG(ERROR) << "worker id: " << worker_config_->worker_id diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc index a14f96b1..a22224fd 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc @@ -22,7 +22,7 @@ #include "src/common/log_adapter.h" #include "src/common/utils.h" #include "src/extendrt/numa_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace { constexpr int kNumIndex = 2; diff --git a/mindspore-lite/src/extendrt/delegate/delegate_utils.cc b/mindspore-lite/src/extendrt/delegate/delegate_utils.cc index 25ad349c..59081be9 100644 --- a/mindspore-lite/src/extendrt/delegate/delegate_utils.cc +++ b/mindspore-lite/src/extendrt/delegate/delegate_utils.cc @@ -15,7 +15,7 @@ */ #include "src/extendrt/delegate/delegate_utils.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::lite { bool IsSubGraphInputTensor(const std::vector &inputs, const TensorInfo &input) { return std::find(inputs.begin(), inputs.end(), input) != inputs.end(); diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt index 452b4342..cdb9661d 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt @@ -12,7 +12,7 @@ set(TENSORRT_PATH $ENV{TENSORRT_PATH}) set(TENSORRT_LIB_PATH ${TENSORRT_PATH}/lib) include_directories(${TENSORRT_PATH}/include) -include_directories(${OPS_DIR}/kernel/cpu) +include_directories(${NNACL_DIR}/../) include_directories(${CCSRC_DIR}/../) include_directories(${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops) @@ -58,7 +58,7 @@ file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false ${CMAKE_CURRENT_SOURCE_DIR}/cuda_impl/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../../extendrt/delegate/delegate_utils.cc ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/cuda_device_info.cc - ${OPS_DIR}/kernel/cpu/nnacl/nnacl_common.c + ${NNACL_DIR}/nnacl_common.c ${TOP_DIR}/mindspore-lite/src/common/file_utils.cc ) diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu index e5150eda..f1ee734b 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu @@ -17,7 +17,7 @@ #include #include #include 
-#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" __device__ __forceinline__ uint64_t Pop(const uint64_t *chunks, uint64_t *curr_chunk, uint8_t bit_count, int32_t *curr_bit_count, int32_t *curr_chunk_index) { diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc index 608c3f4e..d8ade395 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc @@ -16,7 +16,7 @@ #include "src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.h" #include -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc index 412913d0..82233574 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc @@ -16,7 +16,7 @@ #include "src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.h" #include -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "infer/cxx_api/conv2d_transpose_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc index 5856f63d..afddf56b 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc @@ -17,7 +17,7 @@ #include "src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.h" #include #include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "infer/cxx_api/conv2d_transpose_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc index 3d82f8d1..b9cddd0a 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #include "infer/resize.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc index eebb1552..aa99d260 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc @@ -17,7 +17,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/common/utils/anfalgo.h" #include "mindspore/ccsrc/include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h 
b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h index b5787b4e..cee6582f 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h +++ b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h @@ -27,7 +27,7 @@ #include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h" #include "ir/dtype/type_id.h" #include "schema/ops_generated.h" -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "include/api/context.h" #include "mindapi/base/types.h" diff --git a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc b/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc index 67302c5d..e8527d0d 100644 --- a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc +++ b/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc @@ -26,7 +26,7 @@ #include "src/common/draw/drawer.h" #include "src/extendrt/kernel/nnacl/nnacl_base_kernel.h" #include "src/extendrt/kernel/extendrt_kernel_exec.h" -#include "nnacl/format_transpose_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" #include "extendrt/delegate/ascend_native/delegate.h" #include "extendrt/delegate/factory.h" diff --git a/mindspore-lite/src/extendrt/infer_session.cc b/mindspore-lite/src/extendrt/infer_session.cc index 3648c012..2e74d185 100644 --- a/mindspore-lite/src/extendrt/infer_session.cc +++ b/mindspore-lite/src/extendrt/infer_session.cc @@ -23,7 +23,7 @@ #include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h" #include "extendrt/delegate/plugin/ascend_native_executor_plugin.h" #include "extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace { diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc b/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc index d3f531f6..f3b59695 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc +++ b/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc @@ -16,7 +16,7 @@ #include "extendrt/kernel/ascend/model/dyn_shape_process.h" #include -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc index dfef81ad..aafd9a42 100644 --- a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc +++ b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc @@ -20,7 +20,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "common/ms_factory.h" #include "include/api/status.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h index 94dd8fa7..c1b0b874 100644 --- a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h +++ b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h @@ -22,7 +22,7 @@ #include #include #include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "common/common_utils.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc 
b/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc index ee4f07ac..e9ff4be9 100644 --- a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc +++ b/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc @@ -16,7 +16,7 @@ #include "src/extendrt/kernel/cuda/batchtospace.h" #include -#include "nnacl/batch_to_space_parameter.h" +#include "nnacl_c/batch_to_space_parameter.h" namespace mindspore::kernel { int BatchtoSpaceCudaKernel::Prepare() { diff --git a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc index a1b40bc2..eec08314 100644 --- a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc +++ b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_util.cc @@ -23,7 +23,7 @@ #include "ir/value.h" #include "ir/tensor_new.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "src/common/log_util.h" #include "ir/tensor_api.h" diff --git a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h index 3fc1e267..404acb0e 100644 --- a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h +++ b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/arithmetic_populate.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_ #define MINDSPORE_LITE_SRC_COMMON_OPS_POPULATE_ARITHMETIC_POPULATE_H_ -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore { ArithmeticParameter *PopulateArithmeticCommonPara(void *prim); diff --git a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h index 5373ecd7..1a6e1251 100644 --- a/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h +++ b/mindspore-lite/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.h @@ -22,7 +22,7 @@ #include #include "schema/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "src/common/log_adapter.h" #include "src/common/prim_util.h" diff --git a/mindspore-lite/src/infer/primitive_type.cc b/mindspore-lite/src/infer/primitive_type.cc index 312e0562..9eaaf430 100644 --- a/mindspore-lite/src/infer/primitive_type.cc +++ b/mindspore-lite/src/infer/primitive_type.cc @@ -15,7 +15,7 @@ */ #include "src/infer/primitive_type.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::kernel { #ifdef ENABLE_CLOUD_INFERENCE diff --git a/mindspore-lite/src/litert/cpu_info.cc b/mindspore-lite/src/litert/cpu_info.cc index 7510de1a..1302fe02 100644 --- a/mindspore-lite/src/litert/cpu_info.cc +++ b/mindspore-lite/src/litert/cpu_info.cc @@ -18,7 +18,7 @@ #include #include #include "src/common/log_adapter.h" -#include "nnacl/nnacl_utils.h" +#include "nnacl_c/nnacl_utils.h" #if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(MS_COMPILE_IOS) #include #include diff --git a/mindspore-lite/src/litert/cpu_info.h b/mindspore-lite/src/litert/cpu_info.h index d6ac9f76..48f51c50 100644 --- a/mindspore-lite/src/litert/cpu_info.h +++ b/mindspore-lite/src/litert/cpu_info.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_CPU_INFO_H_ #if defined(ENABLE_AVX512) || 
defined(ENABLE_AVX) -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" #endif inline bool PlatformInstructionSetSupportCheck() { diff --git a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc index 3ea67a2e..20dfd8d4 100644 --- a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc +++ b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.cc @@ -15,7 +15,7 @@ */ #include "src/litert/delegate/coreml/op/coreml_op.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore::lite { int CoreMLOp::Init() { auto ret = InitParams(); diff --git a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h index d23dddb9..31d6982b 100644 --- a/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h +++ b/mindspore-lite/src/litert/delegate/coreml/op/coreml_op.h @@ -30,7 +30,7 @@ #include "include/api/data_type.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_NOT_SUPPORT; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/delegate/delegate_utils.cc b/mindspore-lite/src/litert/delegate/delegate_utils.cc index 41d7ea93..c9aeeb11 100644 --- a/mindspore-lite/src/litert/delegate/delegate_utils.cc +++ b/mindspore-lite/src/litert/delegate/delegate_utils.cc @@ -15,7 +15,7 @@ */ #include "src/litert/delegate/delegate_utils.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::lite { void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { int hw8 = plane / C8NUM * C8NUM; diff --git a/mindspore-lite/src/litert/delegate/delegate_utils.h b/mindspore-lite/src/litert/delegate/delegate_utils.h index 7aaa9938..5843699c 100644 --- a/mindspore-lite/src/litert/delegate/delegate_utils.h +++ b/mindspore-lite/src/litert/delegate/delegate_utils.h @@ -19,7 +19,7 @@ #include "include/api/delegate.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { bool IsSubGraphInputTensor(const std::vector &inputs, mindspore::MSTensor input); diff --git a/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt b/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt index 5383d643..9d94e0e4 100644 --- a/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt +++ b/mindspore-lite/src/litert/delegate/npu/CMakeLists.txt @@ -1,5 +1,5 @@ include_directories(${DDK_PATH}) -include_directories(${OPS_DIR}/kernel/cpu) +include_directories(${NNACL_DIR}/../) file(GLOB_RECURSE NPU_RUNTIME_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/op/*.cc diff --git a/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h b/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h index 6ffa3975..d76c6f2f 100644 --- a/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h +++ b/mindspore-lite/src/litert/delegate/npu/npu_converter_utils.h @@ -29,7 +29,7 @@ #include "include/api/data_type.h" #include "include/graph/op/all_ops.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { enum NCHW_SHAPE { NCHW_INVALID = -1, NCHW_N = 0, NCHW_C = 1, NCHW_H = 2, NCHW_W = 3 }; diff --git 
a/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc b/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc index 1b344a90..0b2e176c 100644 --- a/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc +++ b/mindspore-lite/src/litert/delegate/npu/op/convolution_base_npu.cc @@ -18,7 +18,7 @@ #include "src/litert/delegate/npu/npu_converter_utils.h" #include "src/litert/delegate/npu/transpose_kernel.h" #include "src/litert/delegate/delegate_utils.h" -#include "nnacl/int8/pack_int8.h" +#include "nnacl_c/int8/pack_int8.h" namespace mindspore::lite { ConvolutionBaseNPUOp::~ConvolutionBaseNPUOp() { diff --git a/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc b/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc index 02bf60fa..1c171de0 100644 --- a/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc +++ b/mindspore-lite/src/litert/delegate/npu/op/deconvolution_npu.cc @@ -16,7 +16,7 @@ #include "src/litert/delegate/npu/op/deconvolution_npu.h" #include "src/litert/delegate/npu/npu_converter_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/litert/delegate/npu/op/npu_op.h b/mindspore-lite/src/litert/delegate/npu/op/npu_op.h index 9628bf86..215d445c 100644 --- a/mindspore-lite/src/litert/delegate/npu/op/npu_op.h +++ b/mindspore-lite/src/litert/delegate/npu/op/npu_op.h @@ -28,7 +28,7 @@ #include "include/api/data_type.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_NOT_SUPPORT; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc b/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc index 36e9409c..38c826aa 100644 --- a/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc +++ b/mindspore-lite/src/litert/delegate/npu/transpose_kernel.cc @@ -18,7 +18,7 @@ #include "src/litert/delegate/npu/npu_converter_utils.h" #include "src/litert/delegate/npu/op/npu_op.h" #include "src/litert/delegate/delegate_utils.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::lite { int TransposeNPUKernel::Execute() { if (perm_ != NHWC2NCHW_PERM && perm_ != NCHW2NHWC_PERM) { diff --git a/mindspore-lite/src/litert/infer_manager.cc b/mindspore-lite/src/litert/infer_manager.cc index d8a240b4..6d7e7c20 100644 --- a/mindspore-lite/src/litert/infer_manager.cc +++ b/mindspore-lite/src/litert/infer_manager.cc @@ -23,7 +23,7 @@ #include "src/litert/cxx_api/tensor/tensor_impl.h" #include "schema/model_generated.h" #include "include/errorcode.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" #include "src/tensorlist.h" #include "include/registry/register_kernel_interface.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/infer_manager.h b/mindspore-lite/src/litert/infer_manager.h index 39465bfe..9a8766cc 100644 --- a/mindspore-lite/src/litert/infer_manager.h +++ b/mindspore-lite/src/litert/infer_manager.h @@ -24,8 +24,8 @@ #include #include "src/common/prim_util.h" #include "src/tensor.h" -#include "nnacl/tensor_c.h" -#include "nnacl/infer/infer.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/infer/infer.h" #include "include/api/kernel.h" #include "include/api/allocator.h" diff --git a/mindspore-lite/src/litert/inner_context.h 
b/mindspore-lite/src/litert/inner_context.h index 88281eb1..e5f02fb4 100644 --- a/mindspore-lite/src/litert/inner_context.h +++ b/mindspore-lite/src/litert/inner_context.h @@ -27,8 +27,8 @@ #include "src/litert/inner_allocator.h" #endif #include "thread/threadpool.h" -#include "nnacl/op_base.h" -#include "nnacl/kernel.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" #ifdef ENABLE_ARM #include "src/litert/cpu_info.h" #endif diff --git a/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc b/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc index 3954ed42..e46cc6c9 100644 --- a/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc +++ b/mindspore-lite/src/litert/kernel/ascend/src/custom_interface.cc @@ -19,7 +19,7 @@ #include "include/errorcode.h" #include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::kernel { namespace acl { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc index 4455901e..37541ed7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc @@ -18,7 +18,7 @@ #include #include #include -#include "nnacl/base/arithmetic_base.h" +#include "nnacl_c/base/arithmetic_base.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h index f436e8df..3e3ad63d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore::kernel { class ArithmeticBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h b/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h index bc706e4f..b2032cab 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/constant_of_shape.h @@ -19,9 +19,9 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/constant_of_shape_parameter.h" -#include "nnacl/fp32/constant_of_shape_fp32.h" -#include "nnacl/fp16/constant_of_shape_fp16.h" +#include "nnacl_c/constant_of_shape_parameter.h" +#include "nnacl_c/fp32/constant_of_shape_fp32.h" +#include "nnacl_c/fp16/constant_of_shape_fp16.h" namespace mindspore::kernel { class ConstantOfShapeCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc b/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc index edffea42..44ac568e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/custom_is_inf.cc @@ -17,7 +17,7 @@ #include "include/errorcode.h" #include "src/litert/kernel/cpu/base/custom_is_inf.h" #include "src/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc b/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc index 85cfeab6..cb384f0a 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/custom_masked_fill.cc @@ -17,7 +17,7 @@ #include "include/errorcode.h" #include "src/litert/kernel/cpu/base/custom_masked_fill.h" #include "src/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc b/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc index e118e8c1..be974c84 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/base/scatter_nd_binary.h" +#include "nnacl_c/base/scatter_nd_binary.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc index 46477ac9..a13c9c90 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.cc @@ -20,7 +20,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h index 3f479827..41e0458b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/detection_post_process_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/detection_post_process_fp32.h" +#include "nnacl_c/fp32/detection_post_process_fp32.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc index 56ddd255..7f77e978 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/base/format_transpose.h" -#include "nnacl/base/format_transpose.h" +#include "nnacl_c/base/format_transpose.h" #include "src/litert/kernel_registry.h" using mindspore::kernel::KERNEL_ARCH; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h index 062bee69..1468808d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/format_transpose.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/format_transpose_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" namespace mindspore::kernel { class FormatTransposeCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h index f4d34d38..e11b85dd 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_base.h @@ -20,9 +20,9 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" #include "src/litert/kernel/cpu/base/group_convolution_creator.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h index 0afaa11b..c362a7ef 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/group_convolution_creator.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/litert/tensor_category.h" #include "include/api/allocator.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h b/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h index ee1b4c2d..db0fb59f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/layout_transform.h @@ -20,7 +20,7 @@ #ifdef ENABLE_FP16 #include #endif -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "src/tensor.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc b/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc index ad2898e0..812f6dc3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/base/quant_dtype_cast.h" #include -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h b/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h index b352386b..91ba1db0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/random_normal.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h b/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h index b38421c0..963e8085 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/reduce_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" namespace mindspore::kernel { class ReduceBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h b/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h index c82068bc..1f94416e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/resize_base.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" using 
mindspore::schema::PrimitiveType_Resize; using mindspore::schema::ResizeMethod; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h index 6376f26e..043481f9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/base/scatter_nd_binary.h" +#include "nnacl_c/base/scatter_nd_binary.h" namespace mindspore::kernel { class ScatterNDCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h index 36db5771..88972f13 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/scatter_nd_binary.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/base/scatter_nd_binary.h" +#include "nnacl_c/base/scatter_nd_binary.h" namespace mindspore::kernel { constexpr int kScatterUpdateInputIndex = 0; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/split_base.h b/mindspore-lite/src/litert/kernel/cpu/base/split_base.h index 582905e3..8ec8ce5c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/split_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/split_base.h @@ -20,8 +20,8 @@ #include #include "include/errorcode.h" #include "src/litert/lite_kernel.h" -#include "nnacl/split_parameter.h" -#include "nnacl/base/split_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/base/split_base.h" namespace mindspore::kernel { class SplitBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc index 88efe29b..77cf39ce 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.cc @@ -17,7 +17,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "src/tensor.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h index df3e84ea..dc674980 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/split_with_over_lap_base.h @@ -20,8 +20,8 @@ #include #include "include/errorcode.h" #include "src/executor/kernel_exec.h" -#include "nnacl/split_parameter.h" -#include "nnacl/base/split_with_over_lap_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/base/split_with_over_lap_base.h" namespace mindspore::kernel { class SplitWithOverlapBaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h b/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h index 930da7da..d664ab2b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/base/transpose_base.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_BASE_TRANSPOSE_BASE_H_ #include -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "src/litert/lite_kernel.h" namespace 
diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc
index 28203bed..233545ab 100644
--- a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.cc
@@ -16,7 +16,7 @@
#include "bolt/bolt_parameter_manager.h"
#include "bolt/bolt_utils.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
#include "schema/ops_generated.h"
namespace mindspore::kernel::bolt {
diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h
index b0702b92..87720f64 100644
--- a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h
+++ b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_parameter_manager.h
@@ -19,7 +19,7 @@
#include
#include "bolt/common/uni/include/parameter_spec.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "src/common/log_adapter.h"
namespace mindspore::kernel::bolt {
diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h
index 47638204..00da8cdd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h
+++ b/mindspore-lite/src/litert/kernel/cpu/bolt/bolt_utils.h
@@ -17,7 +17,7 @@
#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BOLT_BOLT_UTILS_H_
#include "bolt/common/memory/include/tensor.hpp"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "bolt/common/uni/include/parameter_spec.h"
typedef Tensor BoltTensor;
diff --git a/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc b/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc
index c7ea2639..b7738db3 100644
--- a/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/bolt/convolution_bolt.cc
@@ -16,8 +16,8 @@
#include "bolt/convolution_bolt.h"
#include "bolt/bolt_kernel_manager.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/pack.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/pack.h"
#include "bolt/compute/tensor/include/tensor_computing.h"
#include "bolt/common/memory/include/tensor_desc.h"
#include "bolt/bolt_tensor_utils.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h b/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h
index 3d691bec..6413ccdb 100644
--- a/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h
+++ b/mindspore-lite/src/litert/kernel/cpu/control/tensor_array.h
@@ -19,7 +19,7 @@
#include
#include
-#include "nnacl/tensor_array_parameter.h"
+#include "nnacl_c/tensor_array_parameter.h"
#include "src/litert/lite_kernel.h"
#include "src/tensorlist.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h
index c4f8836a..0cbf87fe 100644
--- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h
+++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_fromtensor.h
@@ -21,7 +21,7 @@
#include "src/litert/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
namespace mindspore::kernel {
class TensorListFromTensorCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h
index d2d81699..374c8234 100644
--- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h
+++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_getitem.h
@@ -21,7 +21,7 @@
#include "src/litert/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
namespace mindspore::kernel {
class TensorListGetItemCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h
index 884e0498..4fdb8d92 100644
--- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h
+++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_reserve.h
@@ -21,7 +21,7 @@
#include "src/litert/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
namespace mindspore::kernel {
class TensorListReserveCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h
index 032646bc..f3f516ba 100644
--- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h
+++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.h
@@ -21,7 +21,7 @@
#include "src/litert/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
namespace mindspore::kernel {
class TensorListSetItemCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h
index f97c09d8..85d3f03f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h
+++ b/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_stack.h
@@ -22,7 +22,7 @@
#include "src/litert/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
-#include "nnacl/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_parameter.h"
namespace mindspore::kernel {
class TensorListStackCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h
index 25d977fd..c4b838cd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/biasadd_fp16.h
@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_BIASADD_FP16_H_
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/fp16/arithmetic_fp16.h"
+#include "nnacl_c/fp16/arithmetic_fp16.h"
namespace mindspore::kernel {
class BiasAddCPUFp16Kernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h
index c88a68e1..447daff3 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/cast_fp16.h
@@ -18,9 +18,9 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
-#include "nnacl/fp16/cast_fp16.h"
-#include "nnacl/base/cast_base.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp16/cast_fp16.h"
+#include "nnacl_c/base/cast_base.h"
namespace mindspore::kernel {
class CastFp16CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc
index 47ac4de3..1a1d5dc6 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/common_fp16.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/litert/kernel/cpu/fp16/common_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc
index 536b962c..493f7fa0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.cc
@@ -15,10 +15,10 @@
*/
#include "src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h"
-#include "nnacl/base/conv1x1_base.h"
-#include "nnacl/fp16/conv_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
+#include "nnacl_c/base/conv1x1_base.h"
+#include "nnacl_c/fp16/conv_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
#include "src/litert/kernel/cpu/fp16/layout_transform_fp16.h"
#include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h
index 76cde5b9..640e789a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_1x1_fp16.h
@@ -22,8 +22,8 @@
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
#include "src/common/utils.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/fp16/matmul_fp16.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
namespace mindspore::kernel {
class Convolution1x1FP16CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc
index 42288a21..53d04f42 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.cc
@@ -24,7 +24,7 @@
#include "src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h"
#include "src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h"
#include "src/litert/kernel/cpu/base/group_convolution_creator.h"
-#include "nnacl/base/conv_common_base.h"
+#include "nnacl_c/base/conv_common_base.h"
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h
index c1581f52..94c294d3 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_delegate_fp16.h
@@ -19,8 +19,8 @@
#include
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/op_base.h"
#define WEIGHT_NEED_FREE 0001
#define BIAS_NEED_FREE 0010
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc
index 435560d7..88e9f720 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.cc
@@ -17,8 +17,8 @@
#ifdef ENABLE_ARM
#include "src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h"
#include "include/errorcode.h"
-#include "nnacl/fp16/pack_fp16.h"
-#include "nnacl/fp16/conv_depthwise_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/conv_depthwise_fp16.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h
index 30383724..28a8714a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_3x3_fp16.h
@@ -21,7 +21,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
namespace mindspore::kernel {
class ConvolutionDepthwise3x3Fp16CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc
index 04312bb6..b38c694a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.cc
@@ -15,8 +15,8 @@
*/
#include "src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h
index 706faeec..bc8ded4b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_fp16.h
@@ -20,7 +20,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp16/conv_depthwise_fp16.h"
+#include "nnacl_c/fp16/conv_depthwise_fp16.h"
#ifdef __cplusplus
extern "C" {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc
index bdeb2e57..50edd50d 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.cc
@@ -15,8 +15,8 @@
*/
#include "src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h
index c944041e..a535bedf 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_depthwise_slidewindow_fp16.h
@@ -20,7 +20,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp16/conv_depthwise_fp16.h"
+#include "nnacl_c/fp16/conv_depthwise_fp16.h"
#ifdef __cplusplus
extern "C" {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc
index cebd9ea7..a8d04577 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_fp16.cc
@@ -17,11 +17,11 @@
#include "src/litert/kernel/cpu/fp16/convolution_fp16.h"
#include
#include "include/errorcode.h"
-#include "nnacl/fp16/conv_fp16.h"
-#include "nnacl/fp16/matmul_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
-#include "nnacl/fp16/winograd_utils_fp16.h"
+#include "nnacl_c/fp16/conv_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/winograd_utils_fp16.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h
index f15c93c0..c12cf980 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/convolution_winograd_fp16.h
@@ -21,10 +21,10 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp16/conv_fp16.h"
-#include "nnacl/fp16/winograd_utils_fp16.h"
+#include "nnacl_c/fp16/conv_fp16.h"
+#include "nnacl_c/fp16/winograd_utils_fp16.h"
#include "src/common/utils.h"
-#include "nnacl/base/minimal_filtering_generator.h"
+#include "nnacl_c/base/minimal_filtering_generator.h"
namespace mindspore::kernel {
class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc
index 14871883..4dda8d7a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/custom_gru_fp16.cc
@@ -20,10 +20,10 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/litert/pack_weight_manager.h"
-#include "nnacl/fp16/pack_fp16.h"
-#include "nnacl/custom_gru_parameter.h"
-#include "nnacl/fp16/custom_gru_fp16.h"
-#include "nnacl/fp16/matmul_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+#include "nnacl_c/custom_gru_parameter.h"
+#include "nnacl_c/fp16/custom_gru_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc
index b42ccdc2..9059260b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.cc
@@ -15,7 +15,7 @@
*/
#include "src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
#include "include/errorcode.h"
using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h
index 58557220..ea28cb1c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_depthwise_fp16.h
@@ -20,7 +20,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp16/conv_depthwise_fp16.h"
+#include "nnacl_c/fp16/conv_depthwise_fp16.h"
#ifdef __cplusplus
extern "C" {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h
index 4a39cbe4..29666797 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_fp16.h
@@ -18,8 +18,8 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_DECONVOLUTION_FP16_H_
#include
-#include "nnacl/fp16/deconv_fp16.h"
-#include "nnacl/fp16/matmul_fp16.h"
+#include "nnacl_c/fp16/deconv_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
#include "src/litert/kernel_registry.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h
index 12120d9c..6e276d3f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/deconvolution_winograd_fp16.h
@@ -19,9 +19,9 @@
#include
#include "include/errorcode.h"
-#include "nnacl/fp16/common_func_fp16.h"
-#include "nnacl/fp16/deconv_winograd_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/common_func_fp16.h"
+#include "nnacl_c/fp16/deconv_winograd_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc
index 5857eeae..0722c9d6 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/dynamic_quant_fp16.cc
@@ -19,9 +19,9 @@
#include "src/litert/kernel_registry.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
-#include "nnacl/dynamic_quant_parameter.h"
-#include "nnacl/fp16/dynamic_quant_fp16.h"
-#include "nnacl/fp16/quant_dtype_cast_fp16.h"
+#include "nnacl_c/dynamic_quant_parameter.h"
+#include "nnacl_c/fp16/dynamic_quant_fp16.h"
+#include "nnacl_c/fp16/quant_dtype_cast_fp16.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h b/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h
index b7656eeb..7447359d 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/fp16_op_handler.h
@@ -16,10 +16,10 @@
#ifdef ENABLE_ARM
#include
#ifdef ENABLE_FP16
-#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
#endif
#endif
-#include "nnacl/nnacl_common.h"
+#include "nnacl_c/nnacl_common.h"
#ifdef __cplusplus
extern "C" {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h
index ca6d09f4..1441abc4 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/group_convolution_fp16.h
@@ -20,9 +20,9 @@
#include
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "src/litert/kernel/cpu/base/group_convolution_base.h"
-#include "nnacl/fp16/conv_fp16.h"
+#include "nnacl_c/fp16/conv_fp16.h"
namespace mindspore::kernel {
class GroupConvolutionFP16CPUKernel : public GroupConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc
index c971a33b..5bdbdd9f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.cc
@@ -18,10 +18,10 @@
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
-#include "nnacl/fp16/gru_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
-#include "nnacl/fp16/lstm_fp16.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp16/gru_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/lstm_fp16.h"
+#include "nnacl_c/errorcode.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h
index baf5191a..c5bf5dac 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/gru_fp16.h
@@ -17,7 +17,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_GRU_FP16_H_
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/gru_parameter.h"
+#include "nnacl_c/gru_parameter.h"
namespace mindspore::kernel {
class GruFp16CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc
index 0c599fdf..9dc1d9ca 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.cc
@@ -17,9 +17,9 @@
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
-#include "nnacl/fp16/cast_fp16.h"
-#include "nnacl/fp16/instance_norm_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/instance_norm_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h
index da58e101..31009de6 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/instance_norm_fp16.h
@@ -17,7 +17,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_INSTANCE_NORM_FP16_H_
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/instance_norm_parameter.h"
+#include "nnacl_c/instance_norm_parameter.h"
using mindspore::lite::InnerContext;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc
index 73ec2025..361a0870 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/layout_transform_fp16.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/litert/kernel/cpu/fp16/layout_transform_fp16.h"
-#include "nnacl/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
#include "src/common/log_adapter.h"
#include "schema/ops_types_generated.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc
index e358e58b..2d82ff32 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc
@@ -16,8 +16,8 @@
#include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h"
#include
-#include "nnacl/fp16/lstm_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/lstm_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h
index a5c15548..f68ac600 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h
@@ -19,7 +19,7 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/lstm_parameter.h"
+#include "nnacl_c/lstm_parameter.h"
namespace mindspore::kernel {
class LstmFp16BaseCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc
index b8100db0..4977df7b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc
@@ -15,7 +15,7 @@
*/
#include "src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h"
-#include "nnacl/fp16/lstm_fp16.h"
+#include "nnacl_c/fp16/lstm_fp16.h"
namespace mindspore::kernel {
namespace {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc
index c6adc6aa..cf7a32e4 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc
@@ -15,7 +15,7 @@
*/
#include "src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h"
-#include "nnacl/fp16/lstm_fp16.h"
+#include "nnacl_c/fp16/lstm_fp16.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc
index 982daeef..9f01f0f5 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc
@@ -16,8 +16,8 @@
#include "src/litert/kernel/cpu/fp16/matmul_base_fp16.h"
#include
-#include "nnacl/fp16/matmul_fp16.h"
-#include "nnacl/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/matmul_fp16.h"
+#include "nnacl_c/fp16/cast_fp16.h"
#include "include/errorcode.h"
using mindspore::lite::kCHWDimNumber;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h
index fd7a27d6..3a467f34 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.h
@@ -23,7 +23,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/common/common.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
namespace mindspore::kernel {
class MatmulBaseFP16CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc
index a8428c39..4841a3bb 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.cc
@@ -15,8 +15,8 @@
*/
#include "src/litert/kernel/cpu/fp16/quant_dtype_cast_fp16.h"
#include
-#include "nnacl/int8/quant_dtype_cast_int8.h"
-#include "nnacl/fp16/quant_dtype_cast_fp16.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/fp16/quant_dtype_cast_fp16.h"
#include "src/litert/kernel_registry.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h b/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h
index cdae563f..faf0944b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16/resize_fp16.h
@@ -19,7 +19,7 @@
#include
#include
#include "src/litert/kernel/cpu/fp32/resize_fp32.h"
-#include "nnacl/fp16/resize_fp16.h"
+#include "nnacl_c/fp16/resize_fp16.h"
namespace mindspore::kernel {
class ResizeFp16CPUKernel : public ResizeCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc
index 8d54bb34..f6d1a1d2 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc
@@ -15,7 +15,7 @@
*/
#include "src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h"
-#include "nnacl/fp16_grad/activation_grad_fp16.h"
+#include "nnacl_c/fp16_grad/activation_grad_fp16.h"
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h
index 0b81bd69..ee88b288 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.h
@@ -19,8 +19,8 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/fp16_grad/activation_grad_fp16.h"
-#include "nnacl/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/fp16_grad/activation_grad_fp16.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
namespace mindspore::kernel {
class ActivationGradCPUKernelFp16 : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc
index 15aa2fe4..064eb5f3 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc
@@ -17,7 +17,7 @@
#include "src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h"
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp16_grad/arithmetic_grad.h"
+#include "nnacl_c/fp16_grad/arithmetic_grad.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h
index 114a8e5d..20c45ceb 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.h
@@ -19,7 +19,7 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/fp16/arithmetic_fp16.h"
+#include "nnacl_c/fp16/arithmetic_fp16.h"
#include "schema/model_generated.h"
using mindspore::schema::PrimitiveType_AddGrad;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h
index 105db40e..ac5e6e73 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_self_grad.h
@@ -19,7 +19,7 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/fp16_grad/arithmetic_self_grad.h"
+#include "nnacl_c/fp16_grad/arithmetic_self_grad.h"
namespace mindspore::kernel {
class ArithmeticSelfGradFp16CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h
index 685ede33..24c72273 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bias_fp16_grad.h
@@ -19,7 +19,7 @@
#include
#include "src/executor/kernel_exec.h"
-#include "nnacl/fp16/arithmetic_fp16.h"
+#include "nnacl_c/fp16/arithmetic_fp16.h"
namespace mindspore::kernel {
class BiasGradCPUKernelFp16 : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc
index 25787bad..b8817f9e 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.cc
@@ -24,7 +24,7 @@
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp16_grad/batch_norm.h"
+#include "nnacl_c/fp16_grad/batch_norm.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h
index 7c9757bc..6b931821 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/bn_fp16_grad.h
@@ -19,7 +19,7 @@
#include
#include "src/executor/kernel_exec.h"
-#include "nnacl/fp32_grad/batch_norm_grad.h"
+#include "nnacl_c/fp32_grad/batch_norm_grad.h"
namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc
index 6eeaf07c..29b3f479 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.cc
@@ -16,11 +16,11 @@
#include "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_filter.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp16_grad/convolution_grad_filter.h"
-#include "nnacl/fp16_grad/pack_fp16_ext.h"
-#include "nnacl/fp16_grad/gemm_fp16.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp16_grad/convolution_grad_filter.h"
+#include "nnacl_c/fp16_grad/pack_fp16_ext.h"
+#include "nnacl_c/fp16_grad/gemm_fp16.h"
+#include "nnacl_c/errorcode.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc
index 3ac59832..80c65cb7 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.cc
@@ -16,10 +16,10 @@
#include "src/litert/kernel/cpu/fp16_grad/convolution_fp16_grad_input.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp16_grad/pack_fp16_ext.h"
-#include "nnacl/fp16_grad/gemm_fp16.h"
-#include "nnacl/fp16_grad/convolution_grad_input.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp16_grad/pack_fp16_ext.h"
+#include "nnacl_c/fp16_grad/gemm_fp16.h"
+#include "nnacl_c/fp16_grad/convolution_grad_input.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc
index e37547ed..86ad09d5 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.cc
@@ -16,11 +16,11 @@
#include
#include "src/litert/kernel/cpu/fp16_grad/dropout_fp16_grad.h"
-#include "nnacl/fp16_grad/dropout_grad.h"
+#include "nnacl_c/fp16_grad/dropout_grad.h"
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
-#include "nnacl/fp32_grad/dropout_parameter.h"
+#include "nnacl_c/fp32_grad/dropout_parameter.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc
index 6576ac15..a77b764a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/layernorm_fp16_grad.cc
@@ -19,8 +19,8 @@
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp16_grad/layernorm_grad.h"
-#include "nnacl/fp32_grad/layernormgrad_parameter.h"
+#include "nnacl_c/fp16_grad/layernorm_grad.h"
+#include "nnacl_c/fp32_grad/layernormgrad_parameter.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc
index 7a56be8e..10cf6d17 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/neg_fp16_grad.cc
@@ -18,7 +18,7 @@
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
-#include "nnacl/fp16/arithmetic_self_fp16.h"
+#include "nnacl_c/fp16/arithmetic_self_fp16.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc
index 1cd92cf0..2dc0d177 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.cc
@@ -17,8 +17,8 @@
#include "src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h"
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp16/pooling_fp16.h"
-#include "nnacl/fp16_grad/pooling_grad.h" +#include "nnacl_c/fp16/pooling_fp16.h" +#include "nnacl_c/fp16_grad/pooling_grad.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h index 678b1a3c..10bf78d7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/pooling_fp16_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/kernel/pooling.h" +#include "nnacl_c/kernel/pooling.h" namespace mindspore::kernel { using mindspore::schema::PadMode; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc index f1050657..57064d04 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.cc @@ -16,8 +16,8 @@ #include #include "src/litert/kernel/cpu/fp16_grad/resize_fp16_grad.h" -#include "nnacl/fp16_grad/resize_grad.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/fp16_grad/resize_grad.h" +#include "nnacl_c/errorcode.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc index b6db3387..ca0fb484 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.cc @@ -19,7 +19,7 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16_grad/strided_slice_grad.h" +#include "nnacl_c/fp16_grad/strided_slice_grad.h" #include "src/common/ops/populate/strided_slice_populate.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h index f90a5915..4c901547 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/strided_slice_fp16_grad.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_GRAD_STRIDED_SLICE_FP16_GRAD_H_ #include -#include "nnacl/fp16_grad/strided_slice_grad.h" +#include "nnacl_c/fp16_grad/strided_slice_grad.h" #include "src/executor/kernel_exec.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc index b7c1703f..f6e1eefc 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp16_grad/unsorted_segment_sum_fp16.cc @@ -19,7 +19,7 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp16_grad/unsorted_segment_sum.h" +#include "nnacl_c/fp16_grad/unsorted_segment_sum.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.cc index a3b712e2..10953861 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.cc +++ 
@@ -19,8 +19,8 @@
#include "src/litert/kernel_registry.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
-#include "nnacl/fp32/adder_fp32.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/adder_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h
index 366e930a..b2b1a7e8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/adder_fp32.h
@@ -20,7 +20,7 @@
#ifdef ENABLE_NNACL_KERNEL_LIB
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "src/litert/kernel/cpu/fp32/convolution_fp32.h"
namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc
index 8f2d96b6..a183abc9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.cc
@@ -19,8 +19,8 @@
#include
#include "src/litert/kernel/cpu/fp32/matmul_fp32.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/activation_fp32.h"
-#include "nnacl/fp32/splice_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/splice_fp32.h"
#include "src/common/utils.h"
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h
index 4cf296fe..917c0d8f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/affine_fp32.h
@@ -19,8 +19,8 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/affine_parameter.h"
-#include "nnacl/splice_parameter.h"
+#include "nnacl_c/affine_parameter.h"
+#include "nnacl_c/splice_parameter.h"
namespace mindspore::kernel {
constexpr auto kAffineMinInputNum = 2;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h
index 04f8066b..80585bf0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/all_gather_fp32.h
@@ -19,7 +19,7 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/all_gather_parameter.h"
+#include "nnacl_c/all_gather_parameter.h"
namespace mindspore::kernel {
class AllGatherCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc
index 9794faf0..bb290ccc 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc
@@ -16,7 +16,7 @@
#include "src/litert/kernel/cpu/fp32/arithmetic_fp32.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h
index 90d158e4..76360efd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/broadcast_to_fp32.h
@@ -19,7 +19,7 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/base/broadcast_to.h"
+#include "nnacl_c/base/broadcast_to.h"
namespace mindspore::kernel {
class BroadcastToCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h
index e10cfacd..c90b2156 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/cast_fp32.h
@@ -20,8 +20,8 @@
#include "include/errorcode.h"
#include "src/litert/lite_kernel.h"
#include "src/tensor.h"
-#include "nnacl/op_base.h"
-#include "nnacl/base/cast_base.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/cast_base.h"
namespace mindspore::kernel {
class CastCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h
index 13482721..15efa5b8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h
@@ -21,13 +21,13 @@
#include
#include "src/litert/lite_kernel.h"
#include "include/errorcode.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
#include "src/litert/kernel/cpu/base/layout_transform.h"
-#include "nnacl/base/conv1x1_base.h"
-#include "nnacl/fp32/common_func_fp32.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/base/conv1x1_base.h"
+#include "nnacl_c/fp32/common_func_fp32.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
namespace mindspore::kernel {
class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc
index ea46a415..9d26b96c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.cc
@@ -25,8 +25,8 @@
#include "src/litert/kernel/cpu/base/group_convolution_creator.h"
#include "src/litert/kernel/cpu/fp32/group_convolution_fp32.h"
#include "src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h"
-#include "nnacl/base/conv_common_base.h"
-#include "nnacl/fp32/conv_sw_arm64_fp32.h"
+#include "nnacl_c/base/conv_common_base.h"
+#include "nnacl_c/fp32/conv_sw_arm64_fp32.h"
#include "schema/model_generated.h"
#include "include/errorcode.h"
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h
index 6fb53eed..3fcbf754 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_delegate_fp32.h
@@ -18,9 +18,9 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/op_base.h"
using mindspore::lite::InnerContext;
namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc
index f4200f54..bacd4d52 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.cc
@@ -16,7 +16,7 @@
#include "src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h"
#include "include/errorcode.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h
index e1866179..fa3ed605 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_3x3_fp32.h
@@ -21,7 +21,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
namespace mindspore::kernel {
class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc
index 510799e0..3ea5ac11 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.cc
@@ -15,10 +15,10 @@
*/
#include "src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h"
-#include "nnacl/intrinsics/ms_simd_cpu_info.h"
+#include "nnacl_c/intrinsics/ms_simd_cpu_info.h"
#include "include/errorcode.h"
#include "src/litert/pack_weight_manager.h"
-#include "nnacl/fp32/conv_depthwise_avx_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_avx_fp32.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h
index 6179c437..8bdef907 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_fp32.h
@@ -21,7 +21,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
namespace mindspore::kernel {
class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h
index 9840e704..ae08a9de 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_indirect_fp32.h
@@ -20,7 +20,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
namespace mindspore::kernel {
class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h
index e4d4bb46..ecc94cd9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_fp32.h
@@ -20,7 +20,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
namespace mindspore::kernel {
class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h
index e959fe45..d6b7ca98 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_depthwise_slidewindow_x86_fp32.h
@@ -20,7 +20,7 @@
#include
#include "src/litert/lite_kernel.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
namespace mindspore::kernel {
class ConvolutionDepthwiseSWCPUKernelX86 : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc
index b813ef80..6ce67127 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.cc
@@ -18,11 +18,11 @@
#include "src/litert/kernel/cpu/fp32/convolution_fp32.h"
#include "src/litert/pack_weight_manager.h"
#include "include/errorcode.h"
-#include "nnacl/common_func.h"
+#include "nnacl_c/common_func.h"
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/conv_common_fp32.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h
index 0c272991..5ca88340 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_fp32.h
@@ -20,7 +20,7 @@
#ifdef ENABLE_NNACL_KERNEL_LIB
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc
index ed7d31fb..de407f71 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.cc
@@ -15,7 +15,7 @@
*/
#include "src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.h"
-#include "nnacl/fp32/conv_common_fp32.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc
index d892b94b..0b7ba386 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.cc
@@ -15,7 +15,7 @@
*/
#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx512_fp32.h"
-#include "nnacl/fp32/conv_im2col_avx512_fp32.h"
+#include "nnacl_c/fp32/conv_im2col_avx512_fp32.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc
index 33a3368f..aaf7f597 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.cc
@@ -15,7 +15,7 @@
*/
#include "src/litert/kernel/cpu/fp32/convolution_im2col_avx_fp32.h"
-#include "nnacl/fp32/conv_common_fp32.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc
index 98085473..9484973a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.cc
@@ -17,11 +17,11 @@
#include "src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h"
#include "src/litert/pack_weight_manager.h"
#include "include/errorcode.h"
-#include "nnacl/common_func.h"
+#include "nnacl_c/common_func.h"
#include "schema/model_generated.h"
#include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/conv_common_fp32.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h
index b7d80d43..70d00e87 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_base_fp32.h
@@ -19,7 +19,7 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc
index 409a1ee7..908af5b0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc
@@ -34,7 +34,7 @@
#if defined(ENABLE_ARM64)
#include "src/litert/kernel/cpu/fp32/convolution_im2col_arm64_fp32.h"
#endif
-#include "nnacl/intrinsics/ms_simd_cpu_info.h"
+#include "nnacl_c/intrinsics/ms_simd_cpu_info.h"
namespace mindspore::kernel {
LiteKernel *CreateConvolutionIm2ColCPUKernel(OpParameter *parameter, const std::vector &inputs,
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h
index c97e62b6..e1b22015 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.h
@@ -19,7 +19,7 @@
#include
#include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
#include "src/litert/kernel/cpu/base/convolution_base.h"
"src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc index 32628a88..11138879 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/litert/kernel/cpu/fp32/convolution_slidewindow_arm64_fp32.h" -#include "nnacl/fp32/conv_sw_arm64_fp32.h" +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h" namespace mindspore::kernel { void ConvolutionSWARM64CPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc index e8f3b62d..54dcd0a2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.cc @@ -15,8 +15,8 @@ */ #ifdef ENABLE_AVX #include "src/litert/kernel/cpu/fp32/convolution_slidewindow_avx_fp32.h" -#include "nnacl/fp32/conv_common_fp32.h" -#include "nnacl/fp32/conv_1x1_x86_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_1x1_x86_fp32.h" namespace mindspore::kernel { void ConvolutionSWAVXCPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc index e46bdb01..4026cf85 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.cc @@ -15,8 +15,8 @@ */ #if defined(ENABLE_AVX) || defined(ENABLE_ARM64) #include "src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" -#include "nnacl/fp32/conv_common_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h index e00ce68d..6f8f4080 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_slidewindow_fp32.h @@ -18,7 +18,7 @@ #if defined(ENABLE_AVX) || defined(ENABLE_ARM64) #include #include "src/executor/kernel_exec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h index 2349dc9b..071d21d3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_sw_1x1_fp32.h @@ -19,9 +19,9 @@ #include #include "include/errorcode.h" -#include "nnacl/intrinsics/ms_simd_cpu_info.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" 
#include "src/litert/kernel/cpu/fp32/matmul_fp32.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc index f8ce6fe5..ee8ea71c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.h" -#include "nnacl/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_utils.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc index 80dd3d9e..1837425a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_winograd_avx_fp32.h" -#include "nnacl/fp32/conv_winograd_fp32.h" -#include "nnacl/pack.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" +#include "nnacl_c/pack.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc index df4bc441..e77872c1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.cc @@ -15,8 +15,8 @@ */ #include "src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h" -#include "nnacl/fp32/conv_winograd_fp32.h" -#include "nnacl/pack.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" +#include "nnacl_c/pack.h" #include "include/errorcode.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h index 4f6a45fd..0968c500 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_base_fp32.h @@ -19,9 +19,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/winograd_transform.h" -#include "nnacl/base/minimal_filtering_generator.h" -#include "nnacl/fp32/conv_winograd_fp32.h" +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" #include "src/litert/kernel/cpu/base/convolution_base.h" #define CONV_INPUT_UNIT_SIZE 8 diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc index bbd7e6d0..c6a2e2e1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc @@ -31,7 +31,7 @@ #if defined(ENABLE_ARM64) #include "src/litert/kernel/cpu/fp32/convolution_winograd_arm64_fp32.h" #endif -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" namespace mindspore::kernel { LiteKernel *CreateConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector &inputs, diff --git 
a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h index e38cd911..67c69734 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc index 9e5acfe6..969ecc6c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "src/litert/kernel/cpu/fp32/cumsum_fp32.h" -#include "nnacl/fp32/cumsum_fp32.h" +#include "nnacl_c/fp32/cumsum_fp32.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h index 16edda65..d26af0b5 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/cumsum_fp32.h @@ -18,7 +18,7 @@ #include #include "include/errorcode.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/cumsum_parameter.h" #include "src/executor/kernel_exec.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc index 37e7d3ac..1b9d7e6d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/custom_gru_fp32.cc @@ -20,9 +20,9 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/litert/pack_weight_manager.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/custom_gru_parameter.h" -#include "nnacl/fp32/custom_gru_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/custom_gru_parameter.h" +#include "nnacl_c/fp32/custom_gru_fp32.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h index eae5f0cf..de8298cb 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_depthwise_fp32.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" namespace mindspore::kernel { class DeconvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h index c76f5ca7..af0579a7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_fp32.h @@ -25,8 +25,8 @@ #include "include/errorcode.h" #include "schema/model_generated.h" #include "src/litert/kernel/cpu/base/convolution_base.h" -#include "nnacl/fp32/deconv_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/deconv_fp32.h" +#include 
"nnacl_c/fp32/matmul_fp32.h" namespace mindspore::kernel { #define DECONV_WINOGRAD_MAX 2000 diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h index 9f4646aa..1aff6a33 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/deconvolution_winograd_fp32.h @@ -24,8 +24,8 @@ #include "src/litert/kernel_registry.h" #include "include/errorcode.h" #include "schema/model_generated.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/deconv_winograd_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/deconv_winograd_fp32.h" #include "src/litert/kernel/cpu/base/convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc index a9907583..b29aa3fd 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h index 4261c6b1..d641c8c0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/detection_post_process_fp32.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/detection_post_process_base.h" -#include "nnacl/fp32/detection_post_process_fp32.h" +#include "nnacl_c/fp32/detection_post_process_fp32.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h index 58be4aad..6813374c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/embedding_lookup_fp32.h" +#include "nnacl_c/fp32/embedding_lookup_fp32.h" namespace mindspore::kernel { class EmbeddingLookupCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc index 4f080ed7..c97df7cc 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.cc @@ -18,8 +18,8 @@ #include #include "src/litert/kernel_registry.h" #include "src/litert/kernel/cpu/base/split_base.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h index 7bb4cc4f..903e8fe7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h +++ 
b/mindspore-lite/src/litert/kernel/cpu/fp32/glu_fp32.h @@ -20,9 +20,9 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" -#include "nnacl/split_parameter.h" -#include "nnacl/glu_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/glu_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h index 8c5c1fac..de04f902 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/group_convolution_fp32.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/kernel/cpu/base/group_convolution_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc index c6794d0f..7c7a8720 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.cc @@ -18,8 +18,8 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/gru_fp32.h" -#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl_c/fp32/gru_fp32.h" +#include "nnacl_c/fp32/lstm_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h index 6db5f006..0736f4ff 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/gru_fp32.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRU_FP32_H_ #include #include "src/litert/lite_kernel.h" -#include "nnacl/gru_parameter.h" +#include "nnacl_c/gru_parameter.h" namespace mindspore::kernel { class GruCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc index 8359695b..83765b80 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.cc @@ -17,8 +17,8 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/instance_norm_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/instance_norm_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h index 56926cec..1ce9cec6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/instance_norm_fp32.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_INSTANCE_NORM_FP32_H_ #include #include "src/litert/lite_kernel.h" -#include "nnacl/instance_norm_parameter.h" +#include "nnacl_c/instance_norm_parameter.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc index 49b1b48d..f15019d4 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/invert_permutation_fp32.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/invert_permutation_fp32.h" #include "src/litert/kernel_registry.h" #include "schema/model_generated.h" -#include "nnacl/fp32/invert_permutation_fp32.h" -#include "mindspore/ops/kernel/cpu/nnacl/errorcode.h" +#include "nnacl_c/fp32/invert_permutation_fp32.h" +#include "nnacl_c/errorcode.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc index 06a57107..67b2771a 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.cc @@ -18,7 +18,7 @@ #include #include "src/litert/kernel/cpu/fp32/l2_norm_fp32.h" #include "include/errorcode.h" -#include "nnacl/fp32/l2_norm_fp32.h" +#include "nnacl_c/fp32/l2_norm_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h index 032a63c4..eb09b3b6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/l2_norm_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/l2_norm_parameter.h" +#include "nnacl_c/l2_norm_parameter.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc index bd0f0e7d..d5975a8d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" #include #include "include/errorcode.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h index 38800c07..2f96c661 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl_c/fp32/lstm_fp32.h" namespace mindspore::kernel { class LstmFp32BaseCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc index 476d5940..97ccf931 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc index 62f9f2b7..317ea2cf 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc +++ 
b/mindspore-lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc index 29e8dba3..42cea273 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.cc @@ -17,9 +17,9 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32.h" #include #include "include/errorcode.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include "src/litert/kernel_registry.h" -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" #if defined(ENABLE_AVX512) #include "src/litert/kernel/cpu/fp32/matmul_fp32_avx512.h" #endif diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h index 262c7da6..1f0f3403 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_MATMUL_FP32_H_ #include -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc index 86e28d2e..55cd42fb 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm32.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_arm32.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { void MatmulFp32ARM32CPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc index a0aaddff..902c01fc 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_arm64.cc @@ -18,9 +18,9 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_arm64.h" #include #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/pack_fp32_opt.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32_opt.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc index b9bb8781..401bff22 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_avx.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { void 
MatmulFp32AVXCPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc index e3cbbbbc..bd96027f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_avx512.cc @@ -17,10 +17,10 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_avx512.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_avx512_fp32.h" -#include "nnacl/fp32/matmul_avx512_mask_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { namespace { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc index 38b72c39..cd202df3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc @@ -16,9 +16,9 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" #include -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/pack_fp32_opt.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32_opt.h" #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) #include "thread/parallel_thread_pool_manager.h" #endif diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h index cf17528c..f1ac4f11 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/pack_weight_manager.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "include/errorcode.h" #include "src/common/common.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc index 7790fb7b..996e968d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_sse.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/cpu/fp32/matmul_fp32_sse.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32_base.h" -#include "nnacl/fp32/matmul_fp32.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" namespace mindspore::kernel { void MatmulFp32SSECPUKernel::InitGlobalVariable() { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc index 00d63f21..5c31033f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.cc @@ -18,7 +18,7 @@ #include #include #include -#include "nnacl/non_max_suppression_parameter.h" +#include "nnacl_c/non_max_suppression_parameter.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h 
b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h index 7e6b011b..0cce8151 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/non_max_suppression_fp32.h @@ -21,7 +21,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/non_max_suppression_parameter.h" +#include "nnacl_c/non_max_suppression_parameter.h" using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc index 59d6a5d7..5a937d59 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/online_fusion/cast_gather_reduce_fp32.h" -#include "nnacl/fp32/online_fusion/cast_gather_reduce_fp32.h" +#include "nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h" #include #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc index 9e890e8e..76a7b698 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/online_fusion/reduce_concat_fp32.h" -#include "nnacl/fp32/online_fusion/reduce_concat_fp32.h" +#include "nnacl_c/fp32/online_fusion/reduce_concat_fp32.h" #include #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc index b77dea64..d6f00fb9 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h" -#include "nnacl/fp32/online_fusion/split_reduce_concat_fp32.h" +#include "nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h" #include #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h index bcaaaabb..376cd7c4 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/online_fusion/split_reduce_concat_fp32.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::kernel { class SplitReduceConcatFusionCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h index 909d751d..6dd140fc 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/reduce_scatter_fp32.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include 
"nnacl/reduce_scatter_parameter.h" +#include "nnacl_c/reduce_scatter_parameter.h" namespace mindspore::kernel { class ReduceScatterCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc index 8f83579f..1bb57069 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h index 1c73c94b..1ad095d4 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/relative_position_attention_fp32.h @@ -19,8 +19,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/attention_fp32.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/fp32/attention_fp32.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::kernel { // inputs: 0:Q 1:K 2:V 3:P 4:WQ 5:WK 6:WV 7:WP 8:PU 9:PV 10:WO 11:BQ 12:BK 13:BV 14:BO 15:output diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h index f7b36b5c..a38309d6 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/resize_fp32.h @@ -19,7 +19,7 @@ #include #include #include "include/errorcode.h" -#include "nnacl/fp32/resize_fp32.h" +#include "nnacl_c/fp32/resize_fp32.h" #include "src/litert/lite_kernel.h" #include "src/litert/kernel/cpu/base/resize_base.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h index a7b28818..5960b370 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/reverse_sequence_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/reverse_sequence_fp32.h" +#include "nnacl_c/fp32/reverse_sequence_fp32.h" namespace mindspore::kernel { class ReverseSequenceCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc index c3412712..cbc26b7e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "src/litert/kernel/cpu/fp32/roi_pooling_fp32.h" -#include "nnacl/fp32/roi_pooling_fp32.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h index f0a02905..de6af7b0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/roi_pooling_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/roi_pooling_fp32.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" namespace mindspore::kernel { class ROIPoolingCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h index 1ea041d2..73c99f6f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_batch_fp32.h @@ -18,8 +18,8 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/space_to_batch_fp32.h" -#include "nnacl/common_func.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" +#include "nnacl_c/common_func.h" namespace mindspore::kernel { class SpaceToBatchCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc index cea0b285..eb574a07 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.cc @@ -14,10 +14,10 @@ * limitations under the License. */ #include "src/litert/kernel/cpu/fp32/space_to_depth_fp32.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/base/space_to_depth_base.h" +#include "nnacl_c/base/space_to_depth_base.h" #include "include/errorcode.h" using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h index 6f92f7b9..2e46b091 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/space_to_depth_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/space_to_depth_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" namespace mindspore::kernel { class SpaceToDepthCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc index 84c9ef33..2c96dac3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_fill_empty_rows_fp32.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc index 165ae71b..98e637f1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc +++ 
b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_reshape_fp32.cc @@ -18,7 +18,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc index 9d5de564..59491b5b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_segment_sum_fp32.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/common_func.h" +#include "nnacl_c/common_func.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc index e3411c7b..23c2d68d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.cc @@ -17,9 +17,9 @@ #include #include #include "include/errorcode.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" #ifdef ENABLE_FP16 -#include "nnacl/fp16/sparse_to_dense_fp16.h" +#include "nnacl_c/fp16/sparse_to_dense_fp16.h" #endif #include "schema/ops_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h index 6ab288d4..c4d0d2ce 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/sparse_to_dense_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" #include "src/litert/kernel/cpu/base/layout_transform.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h index 02c55322..27ebdfa1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/topk_fp32.h @@ -18,9 +18,9 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/topk_fp32.h" +#include "nnacl_c/fp32/topk_fp32.h" #ifdef ENABLE_FP16 -#include "nnacl/fp16/topk_fp16.h" +#include "nnacl_c/fp16/topk_fp16.h" #endif namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc index 238d8543..18233eea 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.cc @@ -16,7 +16,7 @@ */ #include "src/litert/kernel/cpu/fp32/transpose_server_fp32.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h index 26a35c7e..3793c3fc 100644 --- 
a/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/transpose_server_fp32.h @@ -19,7 +19,7 @@ #ifdef BFC_MEMORY #include #include "src/litert/kernel/cpu/base/transpose_base.h" -#include "nnacl/fp32/transpose_server_fp32.h" +#include "nnacl_c/fp32/transpose_server_fp32.h" namespace mindspore::kernel { class TransposeServerCPUKernel : public TransposeBaseCPUKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h index f6cada1d..790d1e29 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/uniform_real_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/random_parameter.h" +#include "nnacl_c/random_parameter.h" namespace mindspore::kernel { class UniformRealCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h index 21212a74..2e11ac41 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/unstack_fp32.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/base/unstack_base.h" +#include "nnacl_c/base/unstack_base.h" namespace mindspore::kernel { class UnstackCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc index e442c216..15f6ac92 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.cc @@ -15,7 +15,7 @@ */ #include "src/litert/kernel/cpu/fp32_grad/activation_grad.h" -#include "nnacl/fp32_grad/activation_grad_fp32.h" +#include "nnacl_c/fp32_grad/activation_grad_fp32.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h index a5be9df8..1e0e79cc 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/activation_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" namespace mindspore::kernel { class ActivationGradCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc index de17bc03..1a783e93 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.cc @@ -20,8 +20,8 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/adam_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/adam_fp32.h" +#include "nnacl_c/op_base.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h index dab71ae5..dfcc8835 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam.h @@ -19,7 +19,7 @@ #include #include 
"src/train/optimizer_kernel.h" -#include "nnacl/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/optimizer.h" namespace mindspore::kernel { constexpr int kAdamLrIndex = 5; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc index 1b3d578e..5c30fb8d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/adam_weight_decay.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/adam_fp32.h" +#include "nnacl_c/fp32/adam_fp32.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h index 2fd6ef78..d6f11298 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/apply_momentum.h @@ -19,7 +19,7 @@ #include #include "src/train/optimizer_kernel.h" -#include "nnacl/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/optimizer.h" namespace mindspore::kernel { constexpr int kApplyMomentumLrIndex = 2; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc index b74231e2..309d584e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.cc @@ -17,10 +17,10 @@ #include "src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32_grad/reduce_grad.h" -#include "nnacl/fp32_grad/arithmetic_grad.h" +#include "nnacl_c/fp32_grad/reduce_grad.h" +#include "nnacl_c/fp32_grad/arithmetic_grad.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h index 92bcb2d8..8402685d 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "schema/model_generated.h" using mindspore::schema::PrimitiveType_AddGrad; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc index 725bddd2..24044b8f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/arithmetic_self_grad.cc @@ -18,9 +18,9 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/arithmetic_fp32.h" -#include "nnacl/fp32_grad/arithmetic_grad.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32_grad/arithmetic_grad.h" +#include "nnacl_c/op_base.h" using mindspore::kernel::KERNEL_ARCH; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h 
b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h index 6d42ac37..9b139115 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/assign.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/optimizer.h" namespace mindspore::kernel { class AssignCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h index 056c88f6..a759e086 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bias_grad.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" namespace mindspore::kernel { class BiasGradCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc index 814977fb..d0376e14 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.cc @@ -17,7 +17,7 @@ #include "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32_grad/binary_cross_entropy.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc index cd50a86b..36b3676c 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.cc @@ -17,7 +17,7 @@ #include "src/litert/kernel/cpu/fp32_grad/binary_cross_entropy_grad.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy_grad.h" using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc index 48550426..ea84f8db 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/bn_grad.cc @@ -23,7 +23,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32_grad/batch_norm_grad.h" +#include "nnacl_c/fp32_grad/batch_norm_grad.h" #include "include/errorcode.h" using mindspore::kernel::KERNEL_ARCH::kCPU; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc index d9eaff5e..0639cec2 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution.cc @@ -15,10 +15,10 @@ */ #include "src/litert/kernel/cpu/fp32_grad/convolution.h" -#include "nnacl/fp32_grad/pack_ext.h" -#include "nnacl/fp32_grad/gemm.h" +#include "nnacl_c/fp32_grad/pack_ext.h" +#include "nnacl_c/fp32_grad/gemm.h" #include "include/errorcode.h" -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" using mindspore::kernel::KERNEL_ARCH; 
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc
index 7bcb222e..35fc941a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc
@@ -16,11 +16,11 @@
 #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32_grad/convolution_grad_filter.h"
-#include "nnacl/fp32_grad/pack_ext.h"
-#include "nnacl/fp32_grad/gemm.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32_grad/convolution_grad_filter.h"
+#include "nnacl_c/fp32_grad/pack_ext.h"
+#include "nnacl_c/fp32_grad/gemm.h"
+#include "nnacl_c/errorcode.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc
index 68e5968a..515ce2ec 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc
@@ -16,10 +16,10 @@
 #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32_grad/pack_ext.h"
-#include "nnacl/fp32_grad/gemm.h"
-#include "nnacl/fp32_grad/convolution_grad_input.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32_grad/pack_ext.h"
+#include "nnacl_c/fp32_grad/gemm.h"
+#include "nnacl_c/fp32_grad/convolution_grad_input.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc
index 5b2dae82..94ced4c2 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc
@@ -16,9 +16,9 @@
 #include "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32_grad/pack_ext.h"
-#include "nnacl/fp32_grad/gemm.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32_grad/pack_ext.h"
+#include "nnacl_c/fp32_grad/gemm.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc
index d5df2396..aef8c820 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout.cc
@@ -19,7 +19,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/dropout_parameter.h"
+#include "nnacl_c/fp32_grad/dropout_parameter.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc
index c22b6b9c..f057ba4f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/dropout_grad.cc
@@ -16,11 +16,11 @@
 #include
 #include "src/litert/kernel/cpu/fp32_grad/dropout_grad.h"
-#include "nnacl/fp32_grad/dropout_grad.h"
+#include "nnacl_c/fp32_grad/dropout_grad.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/dropout_parameter.h"
+#include "nnacl_c/fp32_grad/dropout_parameter.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc
index ea9d9b0f..4edbed4a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/layernorm_grad.cc
@@ -19,9 +19,9 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32_grad/layernorm_grad.h"
-#include "nnacl/fp32_grad/layernormgrad_parameter.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32_grad/layernorm_grad.h"
+#include "nnacl_c/fp32_grad/layernormgrad_parameter.h"
+#include "nnacl_c/errorcode.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc
index b5f0a3ee..1d783c5d 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.cc
@@ -20,7 +20,7 @@
 #include "utils/ms_utils.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h
index 429c875d..6f1f62b0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_data_fp32.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/lstm_grad_fp32.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc
index a4979a9d..a8edc072 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.cc
@@ -20,7 +20,7 @@
 #include "utils/ms_utils.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h
index 6eabbc9c..bca70e2d 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_fp32.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/lstm_grad_fp32.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc
index 51a9d2b6..4b064da1 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.cc
@@ -19,7 +19,7 @@
 #include "utils/ms_utils.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h
index 844bc011..5db3f80e 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/lstm_grad_weight_fp32.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/lstm_grad_fp32.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc
index ab058d9b..916d7e40 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/neg_grad.cc
@@ -18,7 +18,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32/arithmetic_self_fp32.h"
+#include "nnacl_c/fp32/arithmetic_self_fp32.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc
index 8dd4938a..26751714 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.cc
@@ -21,7 +21,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32_grad/nllloss_grad_fp32.h"
+#include "nnacl_c/fp32_grad/nllloss_grad_fp32.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h
index df443376..5b585728 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/nllloss_grad.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/nllloss_parameter.h"
+#include "nnacl_c/nllloss_parameter.h"
 namespace mindspore::kernel {
 class NLLLossGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc
index 55752ce3..f79c2401 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.cc
@@ -17,8 +17,8 @@
 #include "src/litert/kernel/cpu/fp32_grad/pooling_grad.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/pooling_fp32.h"
-#include "nnacl/fp32_grad/pooling_grad.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/fp32_grad/pooling_grad.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h
index 8a9e53d8..0e680172 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/pooling_grad.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/kernel/pooling.h"
+#include "nnacl_c/kernel/pooling.h"
 namespace mindspore::kernel {
 using mindspore::schema::PadMode;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc
index 77e5af4c..2247d2bc 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.cc
@@ -15,11 +15,11 @@
  */
 #include "src/litert/kernel/cpu/fp32_grad/power_grad.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h
index ec7257bc..5c6d2e3c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/power_grad.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/pow_parameter.h"
-#include "nnacl/fp32/power_fp32.h"
+#include "nnacl_c/pow_parameter.h"
+#include "nnacl_c/fp32/power_fp32.h"
 namespace mindspore::kernel {
 class PowerGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc
index 15b035d5..43ec0469 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/resize_grad.cc
@@ -16,12 +16,12 @@
 #include
 #include "src/litert/kernel/cpu/fp32_grad/resize_grad.h"
-#include "nnacl/fp32_grad/resize_grad.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32_grad/resize_grad.h"
+#include "nnacl_c/errorcode.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/nnacl_common.h"
+#include "nnacl_c/nnacl_common.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h
index d43de4a4..2b8758b0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sgd.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "src/train/optimizer_kernel.h"
-#include "nnacl/fp32_grad/optimizer.h"
+#include "nnacl_c/fp32_grad/optimizer.h"
 namespace mindspore::kernel {
 constexpr int kSgdLrIndex = 2;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h
index f4560c53..ea93ea06 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/smooth_l1_loss.h"
+#include "nnacl_c/fp32_grad/smooth_l1_loss.h"
 namespace mindspore::kernel {
 class SmoothL1LossCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h
index d33b758a..2af38da5 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/smooth_l1_loss_grad.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32_grad/smooth_l1_loss.h"
+#include "nnacl_c/fp32_grad/smooth_l1_loss.h"
 namespace mindspore::kernel {
 class SmoothL1LossGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc
index d01602d5..67551d30 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.cc
@@ -16,9 +16,9 @@
 #include
 #include "src/litert/kernel_registry.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/fp32/softmax_fp32.h"
-#include "nnacl/fp32_grad/softmax_cross_entropy_with_logits.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include "nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h"
 #include "src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h
index 51fcca83..416a8152 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h
@@ -19,9 +19,9 @@
 #include
 #include "src/train/loss_kernel.h"
-#include "nnacl/fp32_grad/softmax_grad.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/softmax_parameter.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc
index 55ec5dc1..b2d6f003 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.cc
@@ -17,7 +17,7 @@
 #include "src/litert/kernel/cpu/fp32_grad/softmax_grad.h"
 #include
 #include
-#include "nnacl/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h
index ccde97fb..f87f857c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/softmax_grad.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 namespace mindspore::kernel {
 class SoftmaxGradCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
index 7fbb4cc4..200c0e14 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.cc
@@ -16,9 +16,9 @@
 #include
 #include "src/litert/kernel_registry.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/fp32/softmax_fp32.h"
-#include "nnacl/fp32_grad/softmax_grad_utils.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include "nnacl_c/fp32_grad/softmax_grad_utils.h"
 #include "src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h
index f24c9a4e..12730a04 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h
@@ -19,9 +19,9 @@
 #include
 #include "src/train/loss_kernel.h"
-#include "nnacl/fp32_grad/softmax_grad.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/softmax_parameter.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc
index 91bbd39a..982be44b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.cc
@@ -21,7 +21,7 @@
 #include
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32_grad/strided_slice_grad.h"
+#include "nnacl_c/fp32_grad/strided_slice_grad.h"
 #include "src/common/ops/populate/strided_slice_populate.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h
index f34dd20d..10fc1d50 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/strided_slice_grad.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_STRIDED_SLICE_GRAD_H_
 #include
-#include "nnacl/fp32_grad/strided_slice_grad.h"
+#include "nnacl_c/fp32_grad/strided_slice_grad.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc
index 7c97f07c..681968fc 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_grad/unsorted_segment_sum.cc
@@ -19,7 +19,7 @@
 #include
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/base/unsorted_segment_sum_base.h"
+#include "nnacl_c/base/unsorted_segment_sum_base.h"
 #include "include/errorcode.h"
 using mindspore::kernel::KERNEL_ARCH;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc
index 47373ccd..220901e8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.cc
@@ -23,9 +23,9 @@
 #ifdef ENABLE_ARM64
 #include
 #endif
-#include "nnacl/fp32/matmul_fp32.h"
-#include "nnacl/fp32_sparse/matmul_sparse_x1_fp32.h"
-#include "nnacl/fp32/pack_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h
index 9475cad9..1e6c94b6 100644
--- a/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h
+++ b/mindspore-lite/src/litert/kernel/cpu/fp32_sparse/matmul_sparse_fp32.h
@@ -18,9 +18,9 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_SPARSE_MATMUL_SPARSE_FP32_H_
 #include
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32/transpose_fp32.h"
+#include "nnacl_c/fp32/transpose_fp32.h"
 namespace mindspore::kernel {
 struct SparsityWeight {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc
index 5e6c9ccc..30961020 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.cc
@@ -15,7 +15,7 @@
  */
 #include "src/litert/kernel/cpu/int8/add_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
 #include "src/common/file_utils.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h
index c4c8e80b..586184fc 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/add_int8.h
@@ -20,8 +20,8 @@
 #include
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/add_int8.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/int8/add_int8.h"
+#include "nnacl_c/arithmetic_parameter.h"
 namespace mindspore::kernel {
 class QuantizedAddCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h
index a94a5f00..4eb95ab0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/argminmax_int8.h
@@ -17,12 +17,12 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_ARGMINMAX_INT8_H_
 #include
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/arg_min_max_int8.h"
-#include "nnacl/common_func.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/arg_min_max_int8.h"
+#include "nnacl_c/common_func.h"
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/kernel/arg_min_max.h"
+#include "nnacl_c/kernel/arg_min_max.h"
 namespace mindspore::kernel {
 class ArgMinMaxInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc
index 6312577c..ed066209 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.cc
@@ -17,7 +17,7 @@
 #include "src/litert/kernel/cpu/int8/arithmetic_int8.h"
 #include "src/litert/kernel/cpu/int8/add_int8.h"
 #include "src/litert/kernel/cpu/int8/mul_int8.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/arithmetic_parameter.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h
index f25a928e..3c60c6fd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_int8.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
 #include "schema/model_generated.h"
-#include "nnacl/int8/arithmetic_int8.h"
+#include "nnacl_c/int8/arithmetic_int8.h"
 namespace mindspore::kernel {
 class ArithmeticInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h
index 4477391c..8930fe32 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/arithmetic_self_int8.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/arithmetic_self_parameter.h"
-#include "nnacl/int8/arithmetic_self_int8.h"
+#include "nnacl_c/arithmetic_self_parameter.h"
+#include "nnacl_c/int8/arithmetic_self_int8.h"
 #include "schema/model_generated.h"
 using mindspore::lite::InnerContext;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h
index 64a8ece1..9cd263d8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/batch_to_space_int8.h
@@ -18,9 +18,9 @@
 #include
 #include "include/errorcode.h"
-#include "nnacl/batch_to_space_parameter.h"
-#include "nnacl/base/batch_to_space_base.h"
-#include "nnacl/int8/batch_to_space_int8.h"
+#include "nnacl_c/batch_to_space_parameter.h"
+#include "nnacl_c/base/batch_to_space_base.h"
+#include "nnacl_c/int8/batch_to_space_int8.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc
index 39f9ee3d..10ed8e28 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.cc
@@ -19,7 +19,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h
index 155342ec..3312cd0f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/batchnorm_int8.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/batchnorm_int8.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/int8/batchnorm_int8.h"
+#include "nnacl_c/batchnorm_parameter.h"
 using mindspore::lite::InnerContext;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h
index 5b0e5a03..89d691d9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/concat_int8.h
@@ -19,10 +19,10 @@
 #include
 #include
-#include "nnacl/int8/concat_int8.h"
+#include "nnacl_c/int8/concat_int8.h"
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/kernel/concat.h"
+#include "nnacl_c/kernel/concat.h"
 namespace mindspore::kernel {
 class ConcatInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h
index 0d402dc1..0d009cbf 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_1x1_int8.h
@@ -22,10 +22,10 @@
 #include "include/errorcode.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/int8/conv1x1_int8.h"
-#include "nnacl/base/conv1x1_base.h"
-#include "nnacl/int8/matmul_int8.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/int8/conv1x1_int8.h"
+#include "nnacl_c/base/conv1x1_base.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "src/common/utils.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc
index 95c28434..e408695f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.cc
@@ -15,7 +15,7 @@
  */
 #include "src/litert/kernel/cpu/int8/convolution_3x3_int8.h"
-#include "nnacl/int8/conv3x3_int8.h"
+#include "nnacl_c/int8/conv3x3_int8.h"
 #include "include/errorcode.h"
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h
index 29ee39e4..dd9621bd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_3x3_int8.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32/winograd_transform.h"
+#include "nnacl_c/fp32/winograd_transform.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc
index 4671a249..66229b09 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/conv_depthwise_int8.h"
+#include "nnacl_c/int8/conv_depthwise_int8.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h
index 87969621..12a8f3f9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_3x3_int8.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
 namespace mindspore::kernel {
 class ConvolutionDepthwise3x3Int8CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc
index 86fbe5a1..47385392 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/convolution_depthwise_int8.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/conv_depthwise_int8.h"
+#include "nnacl_c/int8/conv_depthwise_int8.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h
index b069a4b9..0164c0f2 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_int8.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
 namespace mindspore::kernel {
 class ConvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc
index 00f345c1..d1b772d5 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/conv_depthwise_int8.h"
+#include "nnacl_c/int8/conv_depthwise_int8.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h
index 21445a36..61d27ac7 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_depthwise_slidewindow_int8.h
@@ -21,7 +21,7 @@
 #include "src/litert/lite_kernel.h"
 #include "src/common/log_util.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
 namespace mindspore::kernel {
 class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc
index 6118bcfa..2d6b7fe2 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/convolution_int8.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/conv_int8.h"
+#include "nnacl_c/int8/conv_int8.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #ifdef ENABLE_ARM64
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h
index f567a705..32d771fd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8.h
@@ -21,7 +21,7 @@
 #include "src/litert/lite_kernel.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
 #include "src/common/utils.h"
-#include "nnacl/int8/conv_int8.h"
+#include "nnacl_c/int8/conv_int8.h"
 namespace mindspore::kernel {
 class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h
index c116ca1f..f3f7916b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/convolution_int8_creator.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_CONVOLUTION_INT8_CREATOR_H_
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc
index 89a89c9b..8cc7ef1b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/crop_int8.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/base/crop_base.h"
+#include "nnacl_c/base/crop_base.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h
index 309af4ac..a5ac9dc2 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/crop_int8.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "include/errorcode.h"
-#include "nnacl/int8/crop_int8.h"
+#include "nnacl_c/int8/crop_int8.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc
index fef6ae2d..eea6e79e 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/conv_depthwise_int8.h"
+#include "nnacl_c/int8/conv_depthwise_int8.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h
index 2add61b4..3f5492be 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_depthwise_int8.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
-#include "nnacl/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
 namespace mindspore::kernel {
 class DeconvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h
index 26dbb4a6..55f2bf5d 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/deconvolution_int8.h
@@ -21,10 +21,10 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/int8/deconv_int8.h"
-#include "nnacl/int8/common_func_int8.h"
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/int8/deconv_int8.h"
+#include "nnacl_c/int8/common_func_int8.h"
+#include "nnacl_c/int8/matmul_int8.h"
 #include "src/litert/kernel/cpu/base/layout_transform.h"
 #include "src/litert/kernel/cpu/base/convolution_base.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc
index 378b4e3a..ff893b42 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.cc
@@ -18,7 +18,7 @@
 #include
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/nnacl_common.h"
+#include "nnacl_c/nnacl_common.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h
index 291bc071..cf019000 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/depth_to_space_int8.h
@@ -19,10 +19,10 @@
 #include
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/base/depth_to_space_base.h"
-#include "nnacl/int8/depth_to_space_int8.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/kernel/depth_to_space.h"
+#include "nnacl_c/base/depth_to_space_base.h"
+#include "nnacl_c/int8/depth_to_space_int8.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/kernel/depth_to_space.h"
 namespace mindspore::kernel {
 class DepthToSpaceInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc
index 2dd204fc..a4c43b57 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.cc
@@ -18,7 +18,7 @@
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h
index f4ff2e8f..569d1591 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/detection_post_process_int8.h
@@ -20,7 +20,7 @@
 #include
 #include "src/litert/lite_kernel.h"
 #include "src/litert/kernel/cpu/base/detection_post_process_base.h"
-#include "nnacl/fp32/detection_post_process_fp32.h"
+#include "nnacl_c/fp32/detection_post_process_fp32.h"
 using mindspore::lite::InnerContext;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc
index 044c525b..e2c8d868 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.cc
@@ -17,7 +17,7 @@
 #include "src/litert/kernel/cpu/int8/div_int8.h"
 #include
 #include
-#include "nnacl/int8/arithmetic_int8.h"
+#include "nnacl_c/int8/arithmetic_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h
index 15127815..ddebb074 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/div_int8.h
@@ -18,7 +18,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/div_int8.h"
+#include "nnacl_c/int8/div_int8.h"
 namespace mindspore::kernel {
 class DivInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc
index 3c2d79bf..e86c84ec 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.cc
@@ -15,9 +15,9 @@
  */
 #include "src/litert/kernel/cpu/int8/dynamic_gather_int8.h"
 #include
-#include "nnacl/gather_parameter.h"
-#include "nnacl/int8/dynamic_gather_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/int8/dynamic_gather_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h
index 3de46325..8fe495fe 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_gather_int8.h
@@ -18,8 +18,8 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_DYNAMIC_GATHER_INT8_H_
 #include
-#include "nnacl/gather_parameter.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc
index acc43c97..daa79ff2 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.cc
@@ -20,10 +20,10 @@
 #include "src/litert/kernel_registry.h"
 #include "schema/model_generated.h"
 #include "include/errorcode.h"
-#include "nnacl/int8/dynamic_quant_int8.h"
-#include "nnacl/int8/quant_dtype_cast_int8.h"
-#include "nnacl/fp32/transpose_fp32.h"
-#include "nnacl/int8/transpose_int8.h"
+#include "nnacl_c/int8/dynamic_quant_int8.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/fp32/transpose_fp32.h"
+#include "nnacl_c/int8/transpose_int8.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h
index 023f1fab..137e3d0f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/dynamic_quant.h
@@ -21,7 +21,7 @@
 #include
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/dynamic_quant_parameter.h"
+#include "nnacl_c/dynamic_quant_parameter.h"
 namespace mindspore::kernel {
 class DynamicQuantCPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc
index a9e30b1e..40827d17 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.cc
@@ -21,7 +21,7 @@
 #include "schema/model_generated.h"
 #include "include/errorcode.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/int8/gatherNd_int8.h"
+#include "nnacl_c/int8/gatherNd_int8.h"
 using mindspore::kernel::KERNEL_ARCH;
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h
index d9f7f74e..0d16b7dd 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/gatherNd_int8.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_GATHERND_INT8_H_
 #include
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc
index 7f7f815c..3096adc9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.cc
@@ -16,9 +16,9 @@
 #include "src/litert/kernel/cpu/int8/gather_int8.h"
 #include
 #include "src/litert/kernel/cpu/int8/dynamic_gather_int8.h"
-#include "nnacl/gather_parameter.h"
-#include "nnacl/int8/gather_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/int8/gather_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h
index 9b50d763..8f72bcd9 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/gather_int8.h
@@ -18,8 +18,8 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_GATHER_INT8_H_
 #include
-#include "nnacl/gather_parameter.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h
index 58d9cf35..c3d3bd9f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/group_convolution_int8.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/litert/kernel/cpu/base/group_convolution_base.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc
index bff0702f..6a3c9d1e 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/hswish_int8.h"
 #include
-#include "nnacl/int8/hswish_int8.h"
+#include "nnacl_c/int8/hswish_int8.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h
index b43cd24a..e5448aec 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/hswish_int8.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/hswish_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/hswish_int8.h"
+#include "nnacl_c/int8/quantize.h"
 namespace mindspore::kernel {
 class HswishInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h
index 74c7e6e8..4d50c56e 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/l2_norm_int8.h
@@ -18,7 +18,7 @@
 #include
 #include "src/litert/kernel/cpu/fp32/l2_norm_fp32.h"
-#include "nnacl/int8/l2_norm_int8.h"
+#include "nnacl_c/int8/l2_norm_int8.h"
 namespace mindspore::kernel {
 class L2NormInt8CPUKernel : public L2NormCPUKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h
index 1c37e1f5..cebae25a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/layer_norm_int8.h
@@ -18,10 +18,10 @@
 #include
 #include
-#include "nnacl/int8/layer_norm_int8.h"
+#include "nnacl_c/int8/layer_norm_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
-#include "nnacl/kernel/layer_norm.h"
+#include "nnacl_c/kernel/layer_norm.h"
 namespace mindspore::kernel {
 class LayerNormInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h
index 5b65d2d9..1cb72ed7 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/leaky_relu_int8.h
@@ -20,8 +20,8 @@
 #include
 #include
 #include "include/errorcode.h"
-#include "nnacl/fp32/activation_fp32.h"
-#include "nnacl/int8/leaky_relu_int8.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/int8/leaky_relu_int8.h"
 #include "src/litert/lite_kernel.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h
index 22402e5e..06a5bf3d 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_base_int8.h
@@ -20,11 +20,11 @@
 #include
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/common_func.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/common_func_int8.h"
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/common_func_int8.h"
+#include "nnacl_c/int8/matmul_int8.h"
 namespace mindspore::kernel {
 class MatmulBaseInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc
index bab1f730..ce95c450 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc
@@ -15,7 +15,7 @@
  */
 #include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h"
-#include "nnacl/int8/dynamic_matmul_int8.h"
+#include "nnacl_c/int8/dynamic_matmul_int8.h"
 using mindspore::lite::kCHWDimNumber;
 using mindspore::lite::kHWDimNumber;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h
index 858affc8..42e0da55 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h
@@ -21,10 +21,10 @@
 #include
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/common_func.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/common_func_int8.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/common_func_int8.h"
 #include "src/common/common.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc
index 64c0d705..1c8ba263 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_int8.cc
@@ -16,8 +16,8 @@
 #include "src/litert/kernel/cpu/int8/matmul_dynamic_int8.h"
 #include "src/litert/kernel/cpu/int8/opt_op_handler.h"
-#include "nnacl/int8/matmul_int8.h"
-#include "nnacl/int8/dynamic_matmul_int8.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/int8/dynamic_matmul_int8.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_MEMORY_FAILED;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc
index 611a02c1..b2a6ef1b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc
@@ -16,8 +16,8 @@
 #include "src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.h"
 #include
-#include "nnacl/int8/dynamic_matmul_int8.h"
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/int8/dynamic_matmul_int8.h"
+#include "nnacl_c/int8/matmul_int8.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_MEMORY_FAILED;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc
index 85428f6f..fec4a805 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.cc
@@ -17,8 +17,8 @@
 #include "src/litert/kernel/cpu/int8/matmul_int8.h"
 #include "src/litert/kernel/cpu/int8/matmul_dynamic_int8.h"
 #include "src/litert/kernel/cpu/int8/matmul_dynamic_sdot_int8.h"
-#include "nnacl/int8/matmul_int8.h"
-#include "nnacl/common_func.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/common_func.h"
 #include "include/errorcode.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h
index 0eb6ca65..55bc989e 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/matmul_int8.h
@@ -18,8 +18,8 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_MATMUL_INT8_H_
 #include
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/lite_kernel.h"
 #include "src/litert/kernel/cpu/int8/matmul_base_int8.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h
index 6464ea7c..0562a727 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/mul_int8.h
@@ -20,9 +20,9 @@
 #include
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/mul_parameter.h"
-#include "nnacl/int8/mul_int8.h"
-#include "nnacl/int8/arithmetic_int8.h"
+#include "nnacl_c/mul_parameter.h"
+#include "nnacl_c/int8/mul_int8.h"
+#include "nnacl_c/int8/arithmetic_int8.h"
 namespace mindspore::kernel {
 class MulInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc
index 5718c898..229d4506 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/opt_op_handler.h"
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h
index 91691b1b..128147bf 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/opt_op_handler.h
@@ -18,7 +18,7 @@
 #include
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc
index 56fcec7b..27c1f8f5 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "src/litert/kernel_registry.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_MEMORY_FAILED;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h
index 15774096..418d3ded 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/pad_int8.h
@@ -20,10 +20,10 @@
 #include
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/errorcode.h"
-#include "nnacl/pad_parameter.h"
-#include "nnacl/int8/pad_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/pad_parameter.h"
+#include "nnacl_c/int8/pad_int8.h"
+#include "nnacl_c/int8/quantize.h"
 namespace mindspore::kernel {
 class PadInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc
index 662871fd..560a8561 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/pooling_int8.h"
 #include
-#include "nnacl/int8/pooling_int8.h"
+#include "nnacl_c/int8/pooling_int8.h"
 #include "include/errorcode.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h
index ed8db60e..4c2f3dfa 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/pooling_int8.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/pooling_int8.h"
+#include "nnacl_c/int8/pooling_int8.h"
 using mindspore::lite::InnerContext;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc
index 4330b32b..d7cffb45 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/power_int8.h"
 #include
-#include "nnacl/int8/power_int8.h"
+#include "nnacl_c/int8/power_int8.h"
 #include "include/errorcode.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h
index 5f2bfa6c..e5e4532a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/power_int8.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/pow_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/pow_parameter.h"
 namespace mindspore::kernel {
 class PowerInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc
index 008bbe30..0ba3d0f8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.cc
@@ -18,8 +18,8 @@
 #include
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/pack.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/pack.h"
 #include "include/errorcode.h"
 using mindspore::lite::KernelRegistrar;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h
index 149d2090..a1dc3190 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/reduce_int8.h
@@ -19,9 +19,9 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/reduce_parameter.h"
-#include "nnacl/int8/reduce_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/reduce_parameter.h"
+#include "nnacl_c/int8/reduce_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/kernel/cpu/base/reduce_base.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h
index 504d964a..bb9abb53 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/relux_int8.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/fp32/activation_fp32.h"
-#include "nnacl/int8/relux_int8.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/int8/relux_int8.h"
 namespace mindspore::kernel {
 constexpr size_t kRelu6Min = 0;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc
index a91c5bd3..1df86f3c 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/reshape_int8.h"
 #include
-#include "nnacl/int8/reshape_int8.h"
+#include "nnacl_c/int8/reshape_int8.h"
 #include "schema/model_generated.h"
 #include "include/errorcode.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h
index dff07330..91f6e9ba 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/reshape_int8.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/reshape_parameter.h"
+#include "nnacl_c/reshape_parameter.h"
 using mindspore::lite::InnerContext;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc
index 53e39d30..aac5c217 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "include/errorcode.h"
-#include "nnacl/int8/resize_int8.h"
+#include "nnacl_c/int8/resize_int8.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h
index 69d76b80..3e38c803 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/resize_int8.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
 #include "src/litert/kernel/cpu/base/resize_base.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/quantize.h"
 using mindspore::schema::PrimitiveType_Resize;
 using mindspore::schema::ResizeMethod;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h
index e4cdc07e..9c569e9e 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/scale_int8.h
@@ -20,10 +20,10 @@
 #include
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/scale_parameter.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/arithmetic_int8.h"
-#include "nnacl/int8/scale_int8.h"
+#include "nnacl_c/scale_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/arithmetic_int8.h"
+#include "nnacl_c/int8/scale_int8.h"
 namespace mindspore::kernel {
 class ScaleInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc
index 00a3212a..ca4651ce 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.cc
@@ -17,8 +17,8 @@
 #include "src/litert/kernel/cpu/int8/sigmoid_int8.h"
 #include
 #include
-#include "nnacl/int8/sigmoid_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/sigmoid_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "schema/model_generated.h"
 #include "src/litert/kernel_registry.h"
 #include "include/errorcode.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h
index 1f383ae6..68fcd2cc 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/sigmoid_int8.h
@@ -19,7 +19,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/sigmoid_int8.h"
+#include "nnacl_c/int8/sigmoid_int8.h"
 namespace mindspore::kernel {
 class SigmoidInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc
index 9708de00..ab71d4b5 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.cc
@@ -17,9 +17,9 @@
 #include "src/litert/kernel/cpu/int8/slice_int8.h"
 #include
 #include "src/litert/kernel_registry.h"
-#include "nnacl/int8/slice_int8.h"
+#include "nnacl_c/int8/slice_int8.h"
 #include "include/errorcode.h"
-#include "nnacl/base/slice_base.h"
+#include "nnacl_c/base/slice_base.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h
index 2b2745ce..3ebc7536 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/slice_int8.h
@@ -19,9 +19,9 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/kernel/slice.h"
-#include "nnacl/slice_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/kernel/slice.h"
+#include "nnacl_c/slice_parameter.h"
 namespace mindspore::kernel {
 class SliceInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc
index 0d418eb9..34ea1f6f 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.cc
@@ -16,7 +16,7 @@
 #include "src/litert/kernel/cpu/int8/softmax_int8.h"
 #include
-#include "nnacl/int8/softmax_int8.h"
+#include "nnacl_c/int8/softmax_int8.h"
 #include "schema/model_generated.h"
 #include "include/errorcode.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h
index 61395c67..f6c0d0ac 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/softmax_int8.h
@@ -19,8 +19,8 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/softmax_parameter.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/int8/quantize.h"
 namespace mindspore::kernel {
 class SoftmaxInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc
index a478edc6..0cb03a44 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/space_to_batch_int8.cc
@@ -15,8 +15,8 @@
  */
 #include "src/litert/kernel/cpu/int8/space_to_batch_int8.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/fp32/space_to_batch_fp32.h"
-#include "nnacl/int8/space_to_batch_int8.h"
+#include "nnacl_c/fp32/space_to_batch_fp32.h"
+#include "nnacl_c/int8/space_to_batch_int8.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc
index a24bdd3a..87f8ae0a 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/split_int8.cc
@@ -16,8 +16,8 @@
 #include "src/litert/kernel/cpu/int8/split_int8.h"
 #include
-#include "nnacl/split_parameter.h"
-#include "nnacl/int8/split_int8.h"
+#include "nnacl_c/split_parameter.h"
+#include "nnacl_c/int8/split_int8.h"
 #include "include/errorcode.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h
index 0b066038..28380f78 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/squeeze_int8.h
@@ -20,8 +20,8 @@
 #include
 #include "include/errorcode.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/squeeze_int8.h"
-#include "nnacl/squeeze_parameter.h"
+#include "nnacl_c/int8/squeeze_int8.h"
+#include "nnacl_c/squeeze_parameter.h"
 using mindspore::lite::InnerContext;
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h
index af3ab66f..77ecd6c8 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/sub_int8.h
@@ -19,10 +19,10 @@
 #include
 #include
 #include
-#include "nnacl/int8/arithmetic_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/arithmetic_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/sub_int8.h"
+#include "nnacl_c/int8/sub_int8.h"
 namespace mindspore::kernel {
 class SubInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h
index d202a6af..f43454a0 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/tanh_int8.h
@@ -21,8 +21,8 @@
 #include
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/tanh_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/tanh_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "include/errorcode.h"
 namespace mindspore::kernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h
index 3f319f9b..2250ef63 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/topk_int8.h
@@ -18,7 +18,7 @@
 #include
 #include "src/litert/lite_kernel.h"
-#include "nnacl/int8/topk_int8.h"
+#include "nnacl_c/int8/topk_int8.h"
 namespace mindspore::kernel {
 class TopKInt8CPUKernel : public LiteKernel {
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc
index 40d52b5e..6a92baab 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/transpose_int8.cc
@@ -16,8 +16,8 @@
 #include "src/litert/kernel/cpu/int8/transpose_int8.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/int8/transpose_int8.h"
-#include "nnacl/int8/pack_int8.h"
+#include "nnacl_c/int8/transpose_int8.h"
+#include "nnacl_c/int8/pack_int8.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_OK;
diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc
index e539c2ba..e490ab0b 100644
--- a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc
+++ b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ -#include "nnacl/int8/unsqueeze_int8.h" +#include "nnacl_c/int8/unsqueeze_int8.h" #include "src/litert/kernel/cpu/int8/unsqueeze_int8.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h index daa3e5e4..52996b42 100644 --- a/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h +++ b/mindspore-lite/src/litert/kernel/cpu/int8/unsqueeze_int8.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/int8/unsqueeze_int8.h" +#include "nnacl_c/int8/unsqueeze_int8.h" #include "src/litert/kernel/cpu/base/layout_transform.h" using mindspore::lite::InnerContext; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc index 0ae492b0..e1f7671e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_batchnorm.cc @@ -17,7 +17,7 @@ #include "nnacl/nnacl_batchnorm.h" #include "nnacl/nnacl_manager.h" #include "include/errorcode.h" -#include "nnacl/fp32/batchnorm_fp32.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" using mindspore::schema::PrimitiveType_BatchNorm; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc index 86eade3c..17f7fb97 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_convolution.cc @@ -18,8 +18,8 @@ #include "nnacl/cxx_utils.h" #include "src/litert/pack_weight_manager.h" #include "nnacl/nnacl_manager.h" -#include "nnacl/kernel/convolution_base.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/conv_parameter.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc index e604423b..2d9f635e 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_fused_batch_norm.cc @@ -17,8 +17,8 @@ #include "nnacl/nnacl_fused_batch_norm.h" #include "nnacl/nnacl_manager.h" #include "include/errorcode.h" -#include "nnacl/fp32/batchnorm_fp32.h" -#include "nnacl/kernel/fused_batch_norm.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" +#include "nnacl_c/kernel/fused_batch_norm.h" using mindspore::schema::PrimitiveType_FusedBatchNorm; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc index e970b97c..e467a351 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc @@ -18,7 +18,7 @@ #include "nnacl/cxx_utils.h" #include "src/tensor.h" #include "include/errorcode.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/errorcode.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h index c03a22aa..23203686 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_NNACL_KERNEL_H_ 
#include -#include "nnacl/kernel.h" +#include "nnacl_c/kernel.h" #include "src/executor/kernel_exec.h" #include "src/litert/lite_kernel.h" diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc index 187c95ac..7f169ec1 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc @@ -17,7 +17,7 @@ #include "nnacl/nnacl_matmul.h" #include "nnacl/nnacl_manager.h" #include "include/errorcode.h" -#include "nnacl/kernel/matmul_base.h" +#include "nnacl_c/kernel/matmul_base.h" #include "nnacl/cxx_utils.h" #include "src/litert/pack_weight_manager.h" #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h index acd4d31b..7e10fc20 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.h @@ -19,7 +19,7 @@ #include #include "nnacl/nnacl_kernel.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::nnacl { class MatmulKernel : public NNACLKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc index 03c9efd2..a1d2f3d3 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_reduce.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/kernel/reduce.h" +#include "nnacl_c/kernel/reduce.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc index 8370f83e..fb86fdc4 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_strided_slice.cc @@ -19,7 +19,7 @@ #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/kernel/strided_slice.h" +#include "nnacl_c/kernel/strided_slice.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/CMakeLists.txt b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/CMakeLists.txt new file mode 100644 index 00000000..374b8bb8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/CMakeLists.txt @@ -0,0 +1,293 @@ +project(nnacl) + +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${NNACL_DIR}/..) + +set(NNACL_SIMD_DIR ${CMAKE_BINARY_DIR}/src/nnacl_c) +include_directories(${NNACL_SIMD_DIR}/..) 
+file(GLOB SIMD_CONFIG_HEADER + ${NNACL_DIR}/base/*_simd.h.in + ${NNACL_DIR}/fp32/*_simd.h.in + ${NNACL_DIR}/fp32/online_fusion/*_simd.h.in + ${NNACL_DIR}/fp32_grad/*_simd.h.in +) +function(generate_simd_header_code) + foreach(simd_config_file ${SIMD_CONFIG_HEADER}) + string(REGEX MATCHALL "[0-9A-Za-z_]*_simd.h.in" tmp1 ${simd_config_file}) + string(REGEX REPLACE "_simd.h.in" "_${SIMD_INSTRUCTION_LOWER}.h" tmp2 ${tmp1}) + string(REGEX REPLACE "_simd.h.in" "" tmp3 ${tmp1}) + string(TOLOWER ${tmp3} OP_NAME_LOWER) + string(TOUPPER ${tmp3} OP_NAME_UPPER) + configure_file(${NNACL_DIR}/op_simd_header_file.h.in ${NNACL_SIMD_DIR}/${tmp3}_simd.h) + endforeach() +endfunction() + +function(generate_simd_code SIMD_INSTRUCTION SIMD_BLOCK_NUM SIMD_TARGET) + string(TOLOWER ${SIMD_INSTRUCTION} SIMD_INSTRUCTION_LOWER) + set(SIMD_DEFINE "#define MS_SIMD_${SIMD_INSTRUCTION}") + set(SIMD_UNDEF "#undef MS_SIMD_${SIMD_INSTRUCTION}") + set(SIMD_DEF_INSTRUCTION "#define MS_SIMD_INSTRUCTION MS_SIMD_${SIMD_INSTRUCTION}_INSTRUCTION") + set(SIMD_UNDEF_INSTRUCTION "#undef MS_SIMD_INSTRUCTION") + set(SIMD_DEF_BLOCK_NUM "#define BLOCK_NUM ${SIMD_BLOCK_NUM}") + set(SIMD_UNDEF_BLOCK_NUM "#undef BLOCK_NUM") + if(SIMD_INSTRUCTION_LOWER STREQUAL "neon") + set(SIMD_TARGET_BEGIN "") + set(SIMD_TARGET_END "") + else() + set(SIMD_TARGET_BEGIN "#pragma GCC push_options\n#pragma GCC target(${SIMD_TARGET})") + set(SIMD_TARGET_END "#pragma GCC pop_options") + endif() + + set(SIMD_INSTRUCTION_BEGIN "${SIMD_TARGET_BEGIN}\n${SIMD_DEF_INSTRUCTION}\n${SIMD_DEF_BLOCK_NUM}\n${SIMD_DEFINE}") + set(SIMD_INSTRUCTION_END "${SIMD_UNDEF_INSTRUCTION}\n${SIMD_UNDEF_BLOCK_NUM}\n${SIMD_TARGET_END}\n${SIMD_UNDEF}") + foreach(simd_config_file ${SIMD_CONFIG_HEADER}) + string(REGEX MATCHALL "[0-9A-Za-z_]*_simd.h.in" tmp1 ${simd_config_file}) + string(REGEX REPLACE "_simd.h.in" "_${SIMD_INSTRUCTION_LOWER}.h" tmp2 ${tmp1}) + configure_file(${simd_config_file} ${NNACL_SIMD_DIR}/${SIMD_INSTRUCTION_LOWER}/${tmp2}) + endforeach() +endfunction() +generate_simd_code(NEON 4 \"\") +generate_simd_code(SSE 4 \"sse4.1\") +generate_simd_code(AVX 8 "\"avx\", \"avx2\"") +generate_simd_code(AVX512 16 \"avx512f\") +generate_simd_header_code() + +if(ENABLE_CPU AND NOT MSVC) + set(CMAKE_C_FLAGS "-Wno-attributes ${CMAKE_C_FLAGS}") +endif() + +if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64 OR PLATFORM_MCU) + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND DEFINED ARCHS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math -Wno-shorten-64-to-32") + endif() + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND NOT DEFINED ARCHS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math") + endif() + if(TARGET_OHOS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-inline-asm") + endif() +elseif(NOT MSVC) + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections \ + -fdata-sections") + endif() +endif() + +if(NOT MSVC) + if("${X86_64_SIMD}" STREQUAL "sse" OR "${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1") + endif() + if("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma") + endif() + if(WIN32) + if("${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f -fno-asynchronous-unwind-tables") + endif() + endif() 
+endif() + + +########################### files ########################### +file(GLOB KERNEL_SRC + ${NNACL_DIR}/*.c + ${NNACL_DIR}/fp32/*.c + ${NNACL_DIR}/infer/*.c + ${NNACL_DIR}/base/*.c + ${NNACL_DIR}/fp32_grad/*.c + ${NNACL_DIR}/kernel/*.c + ${NNACL_DIR}/experimental/*.c + ${NNACL_DIR}/fp32/online_fusion/*.c +) + +set(KERNEL_AVX512_FILE ${NNACL_DIR}/fp32/matmul_avx512_fp32.c + ${NNACL_DIR}/fp32/matmul_avx512_mask_fp32.c + ${NNACL_DIR}/fp32/conv_im2col_avx512_fp32.c +) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX512_FILE}) + +set(KERNEL_AVX_FILE ${NNACL_DIR}/fp32/conv_sw_avx_fp32.c + ${NNACL_DIR}/fp32/conv_1x1_avx_fp32.c + ${NNACL_DIR}/fp32/matmul_avx_fp32.c + ${NNACL_DIR}/fp32/conv_depthwise_avx_fp32.c) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX_FILE}) + +set(KERNEL_ARM64_FILE ${NNACL_DIR}/fp32/conv_sw_arm64_fp32.c) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_ARM64_FILE}) + +if(NOT MSLITE_ENABLE_RUNTIME_PASS) + list(REMOVE_ITEM KERNEL_SRC ${NNACL_DIR}/infer/shape_fusion_infer.c) +endif() +if((NOT DEFINED MSLITE_ENABLE_INT8) OR MSLITE_ENABLE_INT8) + file(GLOB KERNEL_SRC_INT8 + ${NNACL_DIR}/int8/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INT8} + ) +else() + set(KERNEL_SRC + ${KERNEL_SRC} + ${NNACL_DIR}/int8/pack_int8.c + ${NNACL_DIR}/int8/quantize.c + ) +endif() + +if(MSLITE_ENABLE_SPARSE_COMPUTE) + file(GLOB KERNEL_SRC_SPARSE + ${NNACL_DIR}/fp32_sparse/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_SPARSE} + ) +endif() + +if(MSLITE_ENABLE_STRING_KERNEL) + file(GLOB KERNEL_SRC_INFER_STRING + ${NNACL_DIR}/infer/string/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INFER_STRING} + ) +endif() +if(MSLITE_ENABLE_CONTROLFLOW) + file(GLOB KERNEL_SRC_INFER_CONTROL_TENSORLIST + ${NNACL_DIR}/infer/control/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INFER_CONTROL_TENSORLIST} + ) +endif() +if(PLATFORM_ARM64) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm64/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_SRC ${KERNEL_SRC} ${KERNEL_ARM64_FILE}) +endif() + +if(PLATFORM_ARM32) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm32/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) +endif() + +if("${X86_64_SIMD}" STREQUAL "sse" OR "${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + file(GLOB ASSEMBLY_SSE_SRC ${NNACL_DIR}/intrinsics/sse/*.c) + set_property(SOURCE ${ASSEMBLY_SSE_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_SSE_SRC + ${ASSEMBLY_SSE_SRC} + ${KERNEL_SSE_FILE}) + set_source_files_properties(${MS_X86_SSE_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_SSE_SRC}) +endif() + +if("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + file(GLOB ASSEMBLY_AVX_SRC + ${NNACL_DIR}/intrinsics/avx/*.c + ${NNACL_DIR}/assembly/avx/*.S + ${NNACL_DIR}/intrinsics/ms_simd_cpu_info.c) + set_property(SOURCE ${ASSEMBLY_AVX_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_AVX_SRC + ${ASSEMBLY_AVX_SRC} + ${KERNEL_AVX_FILE}) + set_source_files_properties(${MS_X86_AVX_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_AVX_SRC}) +endif() + +if("${X86_64_SIMD}" STREQUAL "avx512") + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + file(GLOB HPC_SRC ${NNACL_DIR}/experimental/HPC-generator/gemm_avx512/*.c + ${NNACL_DIR}/experimental/HPC-generator/gemm_mask_avx512/*.c) + + set_property(SOURCE ${HPC_SRC} PROPERTY LANGUAGE C) + endif() + + file(GLOB 
ASSEMBLY_AVX512_SRC ${NNACL_DIR}/assembly/avx512/*.S) + set_property(SOURCE ${ASSEMBLY_AVX512_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_AVX512_SRC + ${HPC_SRC} + ${ASSEMBLY_AVX512_SRC} + ${KERNEL_AVX512_FILE}) + + set_source_files_properties(${MS_X86_AVX512_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx512f -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_AVX512_SRC}) +endif() + +if(APPLE) + set_source_files_properties(${ASSEMBLY_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") +endif() + +########################### build nnacl library ######################## +if(NOT MSVC) +string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +endif() + +if(PLATFORM_ARM) + set(NO_FAST_MATH_OPTI ${NNACL_DIR}/fp32/resize_fp32.c) + set_source_files_properties(${NO_FAST_MATH_OPTI} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math") +endif() + +add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC}) + +if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_DEBUG) +endif() + +if(ENABLE_CPU) + if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_ARM ENABLE_ARM64 ENABLE_NEON) + target_compile_options(nnacl_mid PRIVATE -ffast-math -flax-vector-conversions) + elseif("${X86_64_SIMD}" STREQUAL "sse") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE) + elseif("${X86_64_SIMD}" STREQUAL "avx") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX) + elseif("${X86_64_SIMD}" STREQUAL "avx512") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512) + endif() + if(NOT MSVC) + target_compile_options(nnacl_mid PRIVATE -fPIC -fstack-protector-all) + add_library(nnacl SHARED $) + else() + add_library(nnacl STATIC $) + endif() + if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows") + if(NOT CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(nnacl PRIVATE -Wl,-z,relro,-z,now,-z,noexecstack) + target_link_libraries(nnacl PRIVATE m) + endif() + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + target_link_options(nnacl PRIVATE -s) + endif() + endif() +endif() + +set(nnacl_static_obj $) +########################### arm fp16 build optimize library ######################## +if(PLATFORM_ARM) + add_subdirectory(${NNACL_DIR}/optimize) + if(TARGET nnacl_optimize_mid) + set(nnacl_static_obj ${nnacl_static_obj} $) + endif() + if(TARGET nnacl_fp16_mid) + set(nnacl_static_obj ${nnacl_static_obj} $) + endif() +endif() +if(NOT ${CMAKE_GENERATOR} MATCHES "Ninja") + add_library(nnacl_static STATIC ${nnacl_static_obj}) + set_target_properties(nnacl_static PROPERTIES OUTPUT_NAME "nnacl") + set_target_properties(nnacl_static PROPERTIES CLEAN_DIRECT_OUTPUT 1) +endif() diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/OWNERS b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/OWNERS new file mode 100644 index 00000000..35027888 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/OWNERS @@ -0,0 +1,11 @@ +approvers: +- jjfeing +- YeFeng_24 +- fatmouse007fatmouse007 +- xu_anyue + +reviewers: +- liuf9 + +options: + no_parent_owners: true diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/README.md b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/README.md new file mode 100644 index 00000000..c756a7db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/README.md @@ -0,0 +1 @@ +NNACL(neural network accelerated computing library) is a high performance library of 
neural network inference computing kernels for ARM. diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/activation_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/activation_parameter.h new file mode 100644 index 00000000..7a9af324 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/activation_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ACTIVATION_PARAMETER_H_ +#define NNACL_ACTIVATION_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct ActivationParameter { + OpParameter op_parameter_; + int type_; + float alpha_; + float min_val_; + float max_val_; + bool approximate_; +} ActivationParameter; + +#endif // NNACL_ACTIVATION_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/affine_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/affine_parameter.h new file mode 100644 index 00000000..d4d30934 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/affine_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_AFFINE_PARAMETER_H_ +#define NNACL_AFFINE_PARAMETER_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/matmul_parameter.h" +typedef struct AffineParameter { + OpParameter op_parameter_; + // parameters from splice op + int context_size_; + int *context_; + int output_dim_; + // parameters from activation op + int activation_type_; + // parameters from matmul op + MatMulParameter *matmul_parameter_; +} AffineParameter; +#endif // NNACL_AFFINE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/all_gather_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/all_gather_parameter.h new file mode 100644 index 00000000..3ddbde86 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/all_gather_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
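For orientation, the parameter headers being added here are plain C structs handed to kernels through the op_parameter_ base member. A minimal, hypothetical setup for a ReLU6-style activation might look like this (the numeric encoding of type_ is illustrative only; it is not defined in this patch):

#include "nnacl_c/activation_parameter.h"

/* Hypothetical helper; relu6_type stands in for the framework's enum value. */
void SetRelu6Param(ActivationParameter *param, int relu6_type) {
  param->type_ = relu6_type;
  param->alpha_ = 0.0f;     /* unused by plain ReLU6 */
  param->min_val_ = 0.0f;   /* clamp floor */
  param->max_val_ = 6.0f;   /* clamp ceiling */
  param->approximate_ = false;
}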
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ALL_GATHER_PARAMETER_H_ +#define NNACL_ALL_GATHER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct AllGatherParameter { + // primitive parameter + OpParameter op_parameter_; + char group_[DEFAULT_GROUP_NAME_LEN]; + + // other parameter + int rank_size_; +} AllGatherParameter; +#endif // NNACL_ALL_GATHER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arg_min_max_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arg_min_max_parameter.h new file mode 100644 index 00000000..0fdad178 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arg_min_max_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ARG_MIN_MAX_PARAMETER_H_ +#define NNACL_ARG_MIN_MAX_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ArgMinMaxParameter { + OpParameter op_parameter_; + int32_t axis_; + int32_t topk_; + bool keep_dims_; + bool out_value_; +} ArgMinMaxParameter; + +#endif // NNACL_ARG_MIN_MAX_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_parameter.h new file mode 100644 index 00000000..adc4643a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_ARTITHMETIC_PARAMETER_H_ +#define NNACL_ARTITHMETIC_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/nnacl_utils.h" + +#define ARITHMETIC_SUPPORT_DIMS_NUM 10 + +typedef struct ArithmeticParameter { + OpParameter op_parameter_; + bool broadcasting_; + size_t ndim_; + int activation_type_; + int in_shape0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int64_t in_elements_num0_; + int in_shape1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int64_t in_elements_num1_; + + int out_shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_elements_num_; + + int in_strides0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_strides1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_strides_[ARITHMETIC_SUPPORT_DIMS_NUM]; + + int multiples0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int eltwise_mode_; // eltwise need +} ArithmeticParameter; + +#endif // NNACL_ARTITHMETIC_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_self_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_self_parameter.h new file mode 100644 index 00000000..611a3064 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/arithmetic_self_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ARITHMETIC_SELF_PARAMETER_H_ +#define NNACL_ARITHMETIC_SELF_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/quantize.h" + +// For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops. +typedef struct ArithmeticSelfParameter { + OpParameter op_parameter_; + ArithSelfQuantArg quant_arg_; +} ArithmeticSelfParameter; + +#endif // NNACL_ARITHMETIC_SELF_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDw3x3Int8BorderPixel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDw3x3Int8BorderPixel.S new file mode 100644 index 00000000..eadcf972 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDw3x3Int8BorderPixel.S @@ -0,0 +1,128 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
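The in_strides0_/in_strides1_/out_strides_ arrays in ArithmeticParameter above hold row-major strides over the (up to ARITHMETIC_SUPPORT_DIMS_NUM-dimensional) padded shapes used for broadcasting. A hedged sketch of the usual derivation, assuming a plain row-major layout:

/* Row-major strides: strides[i] = product of shape[i+1..ndim-1]. */
void ComputeStridesRef(const int *shape, int *strides, size_t ndim) {
  int stride = 1;
  for (int i = (int)ndim - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= shape[i];
  }
}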
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, +// size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, +// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max +// size_t per_channel) { + +// todo: support per channel +// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step, +// r8: channel, r9: in_zp, r10: out_zp, r11: out_multiplier, r12: left_shift, r13: right_shift +// r14: acc_min, r15: acc_max +asm_function ConvDw3x3Int8BorderPixel + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + + push {r4-r8, r9-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + ldrb r10, [sp, #20] // in_zp + vdup.8 d18, r10 + ldr r10, [sp, #24] // out_zp + vdup.32 q15, r10 + ldr r10, [sp, #28] // out_multiplier + vdup.32 q14, r10 + ldr r10, [sp, #32] // left_shift + vdup.32 q13, r10 + ldr r10, [sp, #36] // right_shift + vdup.32 q12, r10 + ldr r10, [sp, #40] // acc_min + vdup.32 q11, r10 + ldr r10, [sp, #44] // acc_max + vdup.32 q10, r10 + + mov r4, #2 + mul lr, r8, r4 + + LoopC: + mov r9, r1 + mov r10, r2 + ldr r4, [sp] + + vld1.32 {q3}, [r3]! + vld1.32 {q4}, [r3]! + LoopH: + mov r11, r9 + mov r12, r10 + ldr r5, [sp, #4] + LoopW: + vld1.8 {d0}, [r11], r7 + vld1.16 {d2, d3}, [r12], lr // weight + vsubl.s8 q2, d0, d18 // -zp + + vmlal.s16 q3, d4, d2 + vmlal.s16 q4, d5, d3 + + subs r5, r5, #1 + bne LoopW + subs r4, r4, #1 + add r9, r9, r6 + mov r11, #3 + mul r5, lr, r11 + add r10, r10, r5 + bne LoopH + + vshl.s32 q3, q3, q13 + vqrdmulh.s32 q3, q3, q14 + vand q5, q3, q12 + vshr.s32 q5, q5, #31 + vqadd.s32 q3, q3, q5 + vrshl.s32 q3, q3, q12 + vadd.i32 q3, q3, q15 + vmax.s32 q3, q3, q11 + vmin.s32 q3, q3, q10 + vqmovn.s32 d14, q3 + + vshl.s32 q4, q4, q13 + vqrdmulh.s32 q4, q4, q14 + vand q6, q4, q12 + vshr.s32 q6, q6, #31 + vqadd.s32 q4, q4, q6 + vrshl.s32 q4, q4, q12 + vadd.i32 q4, q4, q15 + vmax.s32 q4, q4, q11 + vmin.s32 q4, q4, q10 + vqmovn.s32 d15, q4 + vqmovn.s16 d16, q7 + + vst1.8 {d16}, [r0]! + add r1, r1, #8 + add r2, r2, #16 + + sub r8, r8, #8 + cmp r8, #8 + bge LoopC + + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r8, r9-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Border.S new file mode 100644 index 00000000..5da6cdd4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Border.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
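The vshl/vqrdmulh pair, the vand/vshr/vqadd/vrshl rounding fixup, the vadd of the output zero point, the vmax/vmin clamp, and the vqmovn narrowing in ConvDw3x3Int8BorderPixel above are the standard gemmlowp-style fixed-point requantization. A hedged scalar equivalent (function names are illustrative; the assembly receives right_shift pre-negated for vrshl, while the positive exponent is used here):

#include <stdint.h>

static int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX; /* the one saturating case of vqrdmulh */
  int64_t ab = (int64_t)a * (int64_t)b;
  return (int32_t)((ab + (1LL << 30)) >> 31);             /* doubled high half, rounded */
}

static int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  int32_t mask = ((int32_t)1 << exponent) - 1;
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0); /* round half away from zero */
}

static int8_t RequantizeOne(int32_t acc, int left_shift, int32_t out_multiplier,
                            int right_shift, int32_t out_zp,
                            int32_t acc_min, int32_t acc_max) {
  int32_t v = SatRoundDoublingHighMul(acc * (1 << left_shift), out_multiplier);
  v = RoundingDivideByPOT(v, right_shift) + out_zp;
  if (v < acc_min) v = acc_min;
  if (v > acc_max) v = acc_max;
  return (int8_t)v; /* vqmovn.s32 then vqmovn.s16 */
}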
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6) +// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step, +// r8: kernel_w, r9: relu, r10: relu6 +asm_function ConvDwFp32Border + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r4, [sp] // height + ldr r5, [sp, #4] // width + ldr r6, [sp, #8] // in_kh_step + ldr r7, [sp, #12] // in_kw_step + ldr r8, [sp, #16] // kernel_w + ldr r9, [sp, #20] // relu + ldr r10, [sp, #24] // relu6 + + vld1.32 {q0}, [r3] // bias + vmov.i32 q1, #6 // relu6 + vcvt.f32.s32 q1, q1 + veor q2, q2, q2 // relu + + LoopH: + mov r11, r1 + mov r12, r2 + mov r14, r5 + LoopW: + vld1.32 {q3}, [r11], r7 + vld1.32 {q4}, [r12]! + vmla.f32 q0, q3, q4 + subs r14, r14, #1 + bne LoopW + subs r4, r4, #1 + add r1, r1, r6 + add r2, r2, r8 + bne LoopH + + cmp r10, #0 + bne Relu6 + cmp r9, #0 + bne Relu + b Write + Relu6: + vmin.f32 q0, q0, q1 + Relu: + vmax.f32 q0, q0, q2 + Write: + vst1.32 {q0}, [r0] + + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Center.S new file mode 100644 index 00000000..9935418b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Center.S @@ -0,0 +1,176 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
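ConvDwFp32Border above computes a single C4 output pixel. A hedged scalar reference (the step arguments are byte strides in the assembly; float-element strides are used here for readability, and kernel_w is renamed kernel_w_step to match how the assembly advances the weight row base with it):

#include <stddef.h>

void ConvDwFp32BorderRef(float *dst, const float *src, const float *weight,
                         const float *bias, size_t height, size_t width,
                         size_t in_kh_step, size_t in_kw_step,
                         size_t kernel_w_step, size_t relu, size_t relu6) {
  float acc[4];
  for (int c = 0; c < 4; ++c) acc[c] = bias[c];            /* vld1.32 {q0}, [r3] */
  for (size_t kh = 0; kh < height; ++kh) {
    const float *s = src + kh * in_kh_step;
    const float *w = weight + kh * kernel_w_step;
    for (size_t kw = 0; kw < width; ++kw) {
      for (int c = 0; c < 4; ++c) acc[c] += s[c] * w[c];   /* vmla.f32 q0, q3, q4 */
      s += in_kw_step;
      w += 4;                                              /* packed C4 weights */
    }
  }
  for (int c = 0; c < 4; ++c) {
    float v = acc[c];
    if (relu6 && v > 6.0f) v = 6.0f;   /* Relu6 falls through to Relu in the asm */
    if ((relu || relu6) && v < 0.0f) v = 0.0f;
    dst[c] = v;
  }
}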
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// r0: dst, r1: src, r2: weight, r3: bias, #0: height, #4: width, #8: kernel_h, #12: kernel_w, +// #16: out_h_step, #20: block_channel, #24: in_sh_step, #28: in_sw_step, #32: in_kh_step,#36: in_kw_step +// #40: relu, #44: relu6 +asm_function ConvDwFp32Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r4, [sp] // height + + vld1.32 {q13}, [r3] + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + LoopH: + ldr r1, [sp, #-44] // src_w, src_h = src + ldr r5, [sp, #4] // width + ldr r0, [sp, #-48] // dst_w, dst_h = dst + cmp r5, #4 + blt LoopW + LoopW4: + ldr r11, [sp, #28] // in_sw_step + mov r8, r1 // src_kh, src_w + ldr r2, [sp, #-40] // weight_kh, weight + ldr r6, [sp, #8] // kernel_h + vmov q0, q13 + vmov q1, q13 + vmov q2, q13 + vmov q3, q13 + LoopKh4: + ldr r7, [sp, #12] // kernel_w + mov lr, r8 // src_kw, src_kh + LoopKw4: + ldr r12, [sp, #36] //in_kw_step + mov r10, lr + vld1.32 {q12}, [r2]! + vld1.32 {q4}, [r10] + add r10, r10, r11 + vmla.f32 q0, q4, q12 + vld1.32 {q5}, [r10] + add r10, r10, r11 + vmla.f32 q1, q5, q12 + vld1.32 {q6}, [r10] + add r10, r10, r11 + vmla.f32 q2, q6, q12 + vld1.32 {q7}, [r10] + add r10, r10, r11 + vmla.f32 q3, q7, q12 + subs r7, r7, #1 + add lr, lr, r12 + bne LoopKw4 + ldr r12, [sp, #32] // in_kh_step + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh4 + ldr r12, [sp, #44] + cmp r12, #0 + bne Relu64 + ldr r12, [sp, #40] + cmp r12, #0 + bne Relu4 + b Write4 + Relu64: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 + Relu4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 + Write4: + ldr r12, [sp, #20] // block_channel + vst1.32 {q0}, [r0] + add r0, r0, r12 + vst1.32 {q1}, [r0] + add r0, r0, r12 + vst1.32 {q2}, [r0] + add r0, r0, r12 + vst1.32 {q3}, [r0] + add r0, r0, r12 + mov r12, #4 + mul r11, r11, r12 + add r1, r1, r11 // src_w += in_sw_step + sub r5, r5, #4 + cmp r5, #0 + ble LoopWEnd + cmp r5, #4 + bge LoopW + LoopW: + mov r8, r1 // src_kh, src_w + ldr r2, [sp, #-40] // weight_kh, weight + ldr r6, [sp, #8] // kernel_h + vmov q0, q13 // bias + LoopKh: + ldr r7, [sp, #12] // kernel_w + mov r10, r8 // src_kw, src_kh + LoopKw: + ldr r12, [sp, #36] //in_kw_step + vld1.32 {q1}, [r10] + add r10, r10, r12 + vld1.32 {q12}, [r2]! 
+ vmla.f32 q0, q1, q12 + subs r7, r7, #1 + bne LoopKw + ldr r12, [sp, #32] // in_kh_step + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh + ldr r12, [sp, #44] + cmp r12, #0 + bne Relu6 + ldr r12, [sp, #40] + cmp r12, #0 + bne Relu + b Write + Relu6: + vmin.f32 q0, q0, q14 + Relu: + vmax.f32 q0, q0, q15 + Write: + ldr r12, [sp, #20] // block_channel + vst1.32 {q0}, [r0] // dst_kw += block_channel + add r0, r0, r12 + ldr r12, [sp, #28] // in_sw_step + add r1, r1, r12 // src_w += in_sw_step + subs r5, r5, #1 + bne LoopW + ldr r3, [sp, #16] // out_h_step + ldr r12, [sp, #-48] + add r12, r12, r3 + str r12, [sp, #-48] + + ldr r3, [sp, #24] // in_sh_step + ldr r12, [sp, #-44] // src_h += in_sh_step + add r12, r12, r3 + str r12, [sp, #-44] + + subs r4, r4, #1 // height + bne LoopH +LoopWEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Row.S new file mode 100644 index 00000000..e1b2ff7e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwFp32Row.S @@ -0,0 +1,125 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// voidConvDwFp32Row(float* output_ptr, const float* input_ptr, const float* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// r0: output_ptr, r1: input_ptr, r2: filter_ptr, r3: num_pixels, +// r4: input_channel, r5: input_step +asm_function ConvDwFp32Row + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + + push {r4-r6, r8, r10, r11} + vpush {q4-q7} + add sp, sp, #88 + mov r11, r0 + ldr r4, [sp] + ldr r5, [sp, #4] + mov r6, #4 + mul r5, r5, r6 + cmp r3, #0 + ble End + + LoopNumPixel: + mov r6, r1 // input_ptr + mov r8, r2 // filter_ptr + mov r10, r4 // input_channel + + LoopDepth16In: + cmp r10, #16 + blt L4 + sub r10, r10, #16 + + vld1.32 {q0, q1}, [r6]! + vld1.32 {q4, q5}, [r8]! + vld1.32 {q8, q9}, [r0]! + + cmp r10, #16 + blt LoopDepth16Out + LoopDepth16: + vmla.f32 q8, q0, q4 + vmla.f32 q9, q1, q5 + vst1.32 {q8, q9}, [r11]! + + vld1.32 {q2, q3}, [r6]! + vld1.32 {q6, q7}, [r8]! + vld1.32 {q10, q11}, [r0]! + vmla.f32 q10, q2, q6 + vmla.f32 q11, q3, q7 + vst1.32 {q10, q11}, [r11]! + + vld1.32 {q0, q1}, [r6]! + vld1.32 {q4, q5}, [r8]! + vld1.32 {q8, q9}, [r0]! + + sub r10, r10, #16 + cmp r10, #16 + bge LoopDepth16 + + LoopDepth16Out: + vmla.f32 q8, q0, q4 + vmla.f32 q9, q1, q5 + vst1.32 {q8, q9}, [r11]! + + vld1.32 {q2, q3}, [r6]! + vld1.32 {q6, q7}, [r8]! + vld1.32 {q10, q11}, [r0]! + vmla.f32 q10, q2, q6 + vmla.f32 q11, q3, q7 + vst1.32 {q10, q11}, [r11]! + + L4: + cmp r10, #4 + blt L0 + + LoopDepth4: + vld1.32 {q0}, [r6]! + vld1.32 {q4}, [r8]! + vld1.32 {q8}, [r0]! + vmla.f32 q8, q0, q4 + vst1.32 {q8}, [r11]! 
+ sub r10, r10, #4 + cmp r10, #4 + bge LoopDepth4 + + L0: + cmp r10, #0 + beq Loop16LineEnd + + LoopDepth0: + vld1.32 d0[0], [r6]! + vld1.32 d2[0], [r8]! + vld1.32 d4[0], [r0]! + vmla.f32 s8, s0, s4 + vst1.32 d4[0], [r11]! + subs r10, r10, #1 + bne LoopDepth0 + + Loop16LineEnd: + subs r3, r3, #1 + add r1, r1, r5 + bne LoopNumPixel + + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r6, r8, r10, r11} + bx lr +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Center.S new file mode 100644 index 00000000..fc41d0fd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Center.S @@ -0,0 +1,290 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, +// int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, +// int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp, +// int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min, +// int32_t *acc_max) +// #-48: dst, #-44: src, #-40: weight, #-36: bias, #0: height, #4: width, #8: kernel_h, #12: kernel_w, +// #16: out_h_step, #20: block_channel, #24: in_sh_step, #28: in_sw_step, #32: in_kh_step, #36: in_kw_step +// #40: in_zp, #44: out_zp, #48: out_multiplier, #52: left_shift, #56: right_shift, #60:acc_min, #64: acc_max +asm_function ConvDwInt8Center +// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" +// according to https://stackoverflow.com/questions/53625807 +// even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway +// clang's rule seems more simple, though there are no subroutine calls here +// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + + ldr lr, [sp, #168] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + ldr lr, [sp, #204] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + ldr lr, [sp, #240] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + add sp, sp, #208 + + ldr r1, [sp, #-36] + vld1.32 {q8, q9}, [r1] + ldr r1, [sp, #44] + vld1.32 {q10, q11}, [r1] + ldr r1, [sp, #48] + vld1.32 {q12, q13}, [r1] + ldr r1, [sp, #52] + vld1.32 {q14, q15}, [r1] + + ldr r11, [sp, #28] + ldr r4, [sp] + LoopH: + ldr r1, [sp, #-44] + ldr r0, [sp, #-48] + ldr r5, [sp, #4] + LoopW2: + vmov q4, q8 + vmov q5, q9 + vmov q6, q8 + vmov q7, q9 + mov r7, r1 + ldr r3, [sp, #-40] + ldr r6, [sp, #8] + LoopKH: + mov r9, r7 + ldr r10, [sp, #12] + LoopKW: + mov r8, r9 + vld1.16 {q0}, [r3]! 
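ConvDwFp32Row above is the NEON unrolling (16, then 4, then 1 channel at a time) of the scalar loop below; input_step is in float elements, since the assembly scales it by 4 bytes internally, and the filter is re-read from the start for each pixel:

void ConvDwFp32RowRef(float *output_ptr, const float *input_ptr,
                      const float *filter_ptr, size_t num_pixels,
                      size_t input_channel, size_t input_step) {
  for (size_t p = 0; p < num_pixels; ++p) {
    for (size_t c = 0; c < input_channel; ++c) {
      output_ptr[c] += input_ptr[c] * filter_ptr[c];  /* accumulate in place */
    }
    output_ptr += input_channel;  /* outputs advance contiguously */
    input_ptr += input_step;
  }
}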
+ ldr lr, [sp, #40] + vld1.8 {d2}, [lr] + + vld1.8 {d3}, [r8] + add r8, r8, r11 + vsubl.s8 q2, d3, d2 + vmlal.s16 q4, d4, d0 + vmlal.s16 q5, d5, d1 + + vld1.8 {d3}, [r8] + add r8, r8, r11 + vsubl.s8 q2, d3, d2 + vmlal.s16 q6, d4, d0 + vmlal.s16 q7, d5, d1 + + ldr r12, [sp, #36] + add r9, r9, r12 + subs r10, r10, #1 + bne LoopKW + ldr r12, [sp, #32] + add r7, r7, r12 + subs r6, r6, #1 + bne LoopKH + + vshl.s32 q4, q4, q14 + vshl.s32 q5, q5, q15 + vshl.s32 q6, q6, q14 + vshl.s32 q7, q7, q15 + + vqrdmulh.s32 q4, q4, q12 + vqrdmulh.s32 q5, q5, q13 + vqrdmulh.s32 q6, q6, q12 + vqrdmulh.s32 q7, q7, q13 + + sub lr, sp, #144 + vld1.32 {q0, q1}, [lr] + + vand q2, q4, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q4, q4, q2 + vrshl.s32 q4, q4, q0 + + vand q2, q5, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q5, q5, q2 + vrshl.s32 q5, q5, q1 + + vand q2, q6, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q6, q6, q2 + vrshl.s32 q6, q6, q0 + + vand q2, q7, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q7, q7, q2 + vrshl.s32 q7, q7, q1 + + vadd.i32 q4, q4, q10 + vadd.i32 q5, q5, q11 + vadd.i32 q6, q6, q10 + vadd.i32 q7, q7, q11 + + sub lr, sp, #176 + vld1.32 {q0, q1}, [lr] + vmax.s32 q4, q4, q0 + vmax.s32 q5, q5, q1 + vmax.s32 q6, q6, q0 + vmax.s32 q7, q7, q1 + + sub lr, sp, #208 + vld1.32 {q0, q1}, [lr] + vmin.s32 q4, q4, q0 + vmin.s32 q5, q5, q1 + vmin.s32 q6, q6, q0 + vmin.s32 q7, q7, q1 + + vqmovn.s32 d0, q4 + vqmovn.s32 d1, q5 + vqmovn.s32 d2, q6 + vqmovn.s32 d3, q7 + vqmovn.s16 d0, q0 + vqmovn.s16 d1, q1 + + + ldr r12, [sp, #20] + mov r8, r0 + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! + vst1.8 {d0[4]}, [r8]! + vst1.8 {d0[5]}, [r8]! + vst1.8 {d0[6]}, [r8]! + vst1.8 {d0[7]}, [r8]! + add r0, r0, r12 + + mov r8, r0 + vst1.8 {d1[0]}, [r8]! + vst1.8 {d1[1]}, [r8]! + vst1.8 {d1[2]}, [r8]! + vst1.8 {d1[3]}, [r8]! + vst1.8 {d1[4]}, [r8]! + vst1.8 {d1[5]}, [r8]! + vst1.8 {d1[6]}, [r8]! + vst1.8 {d1[7]}, [r8]! + add r0, r0, r12 + + add r1, r1, r11 + add r1, r1, r11 + subs r5, r5, #2 + beq LoopEndW + cmp r5, #2 + bge LoopW2 + + LoopW: + vmov q4, q8 + vmov q5, q9 + mov r7, r1 + ldr r3, [sp, #-40] + ldr r6, [sp, #8] + LoopKH1: + mov r9, r7 + ldr r10, [sp, #12] + LoopKW1: + vld1.16 {q0}, [r3]! + ldr lr, [sp, #40] + vld1.8 {d2}, [lr] + + vld1.8 {d3}, [r9] + vsubl.s8 q2, d3, d2 + vmlal.s16 q4, d4, d0 + vmlal.s16 q5, d5, d1 + + ldr r12, [sp, #36] + add r9, r9, r12 + subs r10, r10, #1 + bne LoopKW1 + ldr r12, [sp, #32] + add r7, r7, r12 + subs r6, r6, #1 + bne LoopKH1 + + vshl.s32 q4, q4, q14 + vshl.s32 q5, q5, q15 + + vqrdmulh.s32 q4, q4, q12 + vqrdmulh.s32 q5, q5, q13 + + sub lr, sp, #144 + vld1.32 {q0, q1}, [lr] + vand q2, q4, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q4, q4, q2 + vrshl.s32 q4, q4, q0 + + vand q2, q5, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q5, q5, q2 + vrshl.s32 q5, q5, q1 + + vadd.i32 q4, q4, q10 + vadd.i32 q5, q5, q11 + + sub lr, sp, #176 + vld1.32 {q0, q1}, [lr] + vmax.s32 q4, q4, q0 + vmax.s32 q5, q5, q1 + + sub lr, sp, #208 + vld1.32 {q0, q1}, [lr] + vmin.s32 q4, q4, q0 + vmin.s32 q5, q5, q1 + + vqmovn.s32 d0, q4 + vqmovn.s32 d1, q5 + vqmovn.s16 d0, q0 + + mov r8, r0 + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! + vst1.8 {d0[4]}, [r8]! + vst1.8 {d0[5]}, [r8]! + vst1.8 {d0[6]}, [r8]! + vst1.8 {d0[7]}, [r8]! 
+ ldr r12, [sp, #20] + add r0, r0, r12 + add r1, r1, r11 + subs r5, r5, #1 + bne LoopW + + LoopEndW: + ldr r12, [sp, #16] + ldr r1, [sp, #-48] + add r1, r1, r12 + str r1, [sp, #-48] + ldr r12, [sp, #24] + ldr r1, [sp, #-44] + add r1, r1, r12 + str r1, [sp, #-44] + subs r4, r4, #1 + bne LoopH + + LoopEndH: + sub sp, sp, #208 + vpop {q0, q1} + vpop {q0, q1} + vpop {q0, q1} + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4.S new file mode 100644 index 00000000..d74d5e2c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4.S @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, +// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier, +// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max + +asm_function ConvDwInt8PostAlign4 + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10} + vpush {q4-q7} + add sp, sp, #88 + + vdup.32 q15, r3 // output_zp + + ldr r4, [sp] // out_multiplier + vdup.32 q14, r4 + + ldr r5, [sp, #4] // left_shift + vdup.32 q13, r5 + + ldr r6, [sp, #8] // right_shift + vdup.32 q12, r6 + + ldr r7, [sp, #12] // acc_min + vdup.32 q11, r7 + + ldr r8, [sp, #16] // acc_max + vdup.32 q10, r8 + + mov r10, r0 + + LoopDepth8: + cmp r2, #8 + blt End + vld1.32 {q0}, [r1]! + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + vqmovn.s32 d4, q0 + + vld1.32 {q1}, [r1]! + vshl.s32 q1, q1, q13 + vqrdmulh.s32 q1, q1, q14 + vand q4, q1, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q1, q1, q4 + vrshl.s32 q1, q1, q12 + vadd.i32 q1, q1, q15 + vmax.s32 q1, q1, q11 + vmin.s32 q1, q1, q10 + vqmovn.s32 d5, q1 + vqmovn.s16 d4, q2 + + vst1.8 {d4}, [r10]! + + sub r2, r2, #8 + b LoopDepth8 + + LoopDepth4: + cmp r2, #4 + blt End + vld1.32 {q0}, [r1]! 
+ + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r10]! + vst1.8 {d0[1]}, [r10]! + vst1.8 {d0[2]}, [r10]! + vst1.8 {d0[3]}, [r10]! + + sub r2, r2, #4 + b LoopDepth4 + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r8, r10} + bx lr + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 00000000..40bbb8a4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier, +// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max + +asm_function ConvDwInt8PostAlign4PerChannel + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10} + vpush {q4-q7} + add sp, sp, #88 + + vdup.32 q15, r3 // output_zp + + ldr r4, [sp] // out_multiplier + ldr r5, [sp, #4] // left_shift + ldr r6, [sp, #8] // right_shift + + ldr r7, [sp, #12] // acc_min + vdup.32 q11, r7 + + ldr r8, [sp, #16] // acc_max + vdup.32 q10, r8 + + mov r10, r0 + + LoopDepth8: + cmp r2, #8 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + vqmovn.s32 d4, q0 + + vld1.32 {q1}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q1, q1, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q1, q1, q14 + vld1.32 {q12}, [r6]! + vand q4, q1, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q1, q1, q4 + vrshl.s32 q1, q1, q12 + vadd.i32 q1, q1, q15 + vmax.s32 q1, q1, q11 + vmin.s32 q1, q1, q10 + vqmovn.s32 d5, q1 + vqmovn.s16 d4, q2 + + vst1.8 {d4}, [r10]! + + sub r2, r2, #8 + b LoopDepth8 + + LoopDepth4: + cmp r2, #4 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! 
+ vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r10]! + vst1.8 {d0[1]}, [r10]! + vst1.8 {d0[2]}, [r10]! + vst1.8 {d0[3]}, [r10]! + + sub r2, r2, #4 + b LoopDepth4 + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r8, r10} + bx lr + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Row.S new file mode 100644 index 00000000..2833b4a2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/ConvDwInt8Row.S @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, +// int output_channel, int input_step, int8_t input_zp) +// r0: output_ptr, r1: input_ptr, r2: weight_ptr, r3: num_pixels, +// r4: output_channel, r5: input_step, r6: input_zp, + +asm_function ConvDwInt8Row + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r9-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + cmp r3, #0 + beq End + + ldr r4, [sp] // channel + ldr r5, [sp, #4] // input_step + ldr r6, [sp, #8] // input_zp + vdup.8 d30, r6 + + mov r7, r0 + + LoopPixel: + mov r8, r1 // input + mov r10, r2 // weight + mov r11, r4 + + LoopDepth16In: + cmp r11, #16 + blt L8 + sub r11, r11, #16 + + vld1.8 {q0}, [r8]! + vld1.16 {q1, q2}, [r10]! // weight + + vsubl.s8 q3, d0, d30 // -zp + vld1.32 {q4, q5}, [r0]! + vmlal.s16 q4, d6, d2 + vmlal.s16 q5, d7, d3 + + cmp r11, #16 + blt LoopDepth16Out + LoopDepth16: + vst1.32 {q4, q5}, [r7]! + + vsubl.s8 q6, d1, d30 + vld1.32 {q7, q8}, [r0]! + vmlal.s16 q7, d12, d4 + vmlal.s16 q8, d13, d5 + vst1.32 {q7, q8}, [r7]! + + vld1.8 {q0}, [r8]! + vld1.16 {q1, q2}, [r10]! // weight + + vsubl.s8 q3, d0, d30 // -zp + vld1.32 {q4, q5}, [r0]! + vmlal.s16 q4, d6, d2 + vmlal.s16 q5, d7, d3 + + sub r11, r11, #16 + cmp r11, #16 + bge LoopDepth16 + + LoopDepth16Out: + vst1.32 {q4, q5}, [r7]! + + vsubl.s8 q6, d1, d30 + vld1.32 {q7, q8}, [r0]! + vmlal.s16 q7, d12, d4 + vmlal.s16 q8, d13, d5 + vst1.32 {q7, q8}, [r7]! + + L8: + cmp r11, #8 + blt L0 + + LoopDepth8: + vld1.8 {d0}, [r8]! + vld1.16 {d2, d3}, [r10]! 
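The per-channel variant above differs from the per-layer one only in where the requantization parameters come from: out_multiplier, left_shift and right_shift become per-channel arrays that are reloaded for every group of four channels (the extra vld1.32 loads from r4, r5 and r6 inside the loop). In terms of the RequantizeOne helper sketched earlier (my naming, not from the patch):

    /* Per-channel post-processing: one multiplier/shift triple per output channel. */
    void ConvDwInt8PostAlign4PerChannelRef(int8_t *dst, const int32_t *buffer, int channel4,
                                           int32_t output_zp, const int32_t *out_multiplier,
                                           const int *left_shift, const int *right_shift,
                                           int32_t acc_min, int32_t acc_max) {
      for (int c = 0; c < channel4; ++c) {
        dst[c] = RequantizeOne(buffer[c], left_shift[c], out_multiplier[c], right_shift[c],
                               output_zp, acc_min, acc_max);
      }
    }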
// weight + + vsubl.s8 q2, d0, d30 // -zp + + vld1.32 {q3}, [r0]! + vmlal.s16 q3, d4, d2 + vst1.32 {q3}, [r7]! + + vld1.32 {q4}, [r0]! + vmlal.s16 q4, d5, d3 + vst1.32 {q4}, [r7]! + + sub r11, r11, #8 + cmp r11, #8 + bge LoopDepth8 + + L0: + cmp r11, #0 + beq LoopDepthEnd + + LoopDepth0: + ldrsb r12, [r8], #1 + ldrsh r9, [r10], #2 + sub r12, r12, r6 + + ldr lr, [r0], #4 + smlabb r12, r12, r9, lr + str r12, [r7], #4 + + subs r11, r11, #1 + bne L0 + + LoopDepthEnd: + add r1, r1, r5 + subs r3, r3, #1 + bne LoopPixel + + End: + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r8, r9-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwFp32Center.S new file mode 100644 index 00000000..f42a1b82 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwFp32Center.S @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +asm_function DeconvDwFp32Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.32 {q1}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.32 {q2}, [r2]! 
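ConvDwInt8Row, just closed above, is the depthwise MAC over one output row: the int8 input is zero-point-corrected, widened, multiplied by the pre-widened int16 weights, and accumulated in place into the int32 output buffer. A scalar reference matching the prototype in the header comment (the L0 tail loop above is exactly this, one element at a time):

    /* Scalar reference for ConvDwInt8Row: in-place accumulation over one row. */
    void ConvDwInt8RowRef(int32_t *output_ptr, const int8_t *input_ptr,
                          const int16_t *weight_ptr, int num_pixels, int output_channel,
                          int input_step, int8_t input_zp) {
      for (int p = 0; p < num_pixels; ++p) {
        for (int c = 0; c < output_channel; ++c) {
          *output_ptr++ += ((int32_t)input_ptr[c] - input_zp) * weight_ptr[c];
        }
        input_ptr += input_step; /* advance to the next pixel's channel block */
      }
    }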
+ vmla.f32 q0, q1, q2 + vst1.32 {q0}, [r7] + add r7, r7, r10 + subs r12, r12, #1 + bne LoopKw + add r6, r6, r11 + subs r5, r5, #1 + bne LoopKh + ldr r12, [sp, #72] + add r0, r0, r12 + ldr r8, [sp, #64] + add r1, r1, r8 + subs r4, r4, #1 + bne LoopW + ldr r8, [sp, #68] + ldr r12, [sp] + add r12, r12, r8 + str r12, [sp] + ldr r8, [sp, #60] + ldr r12, [sp, #4] + add r12, r12, r8 + str r12, [sp, #4] + subs r3, r3, #1 + bne LoopH + + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Center.S new file mode 100644 index 00000000..e7e6cd15 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Center.S @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +asm_function DeconvDwInt8Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.16 {d2}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.16 {d24}, [r2]! 
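DeconvDwFp32Center, closed above, is the scatter-accumulate core of depthwise deconvolution: each 4-float channel block of the source is multiplied by every kernel tap and added into a window of the destination. The assembly walks raw byte strides taken from the stack; the sketch below uses element strides for readability, and the C4 block width of 4 floats is an assumption read off the q-register loads. The int8 variant that follows is the same loop with vmlal.s16 widening into int32 accumulators.

    /* Element-stride sketch of DeconvDwFp32Center (C4 = 4-float channel blocks). */
    void DeconvDwFp32CenterRef(float *dst, const float *src, const float *weight,
                               size_t height, size_t width, size_t kernel_h, size_t kernel_w,
                               size_t out_h_step, size_t block_channel, size_t in_sh_step,
                               size_t in_sw_step, size_t in_kh_step, size_t in_kw_step) {
      for (size_t h = 0; h < height; ++h) {
        float *dst_w = dst;
        const float *src_w = src;
        for (size_t w = 0; w < width; ++w) {
          const float *weight_kh = weight; /* weight pointer restarts per pixel */
          for (size_t kh = 0; kh < kernel_h; ++kh) {
            for (size_t kw = 0; kw < kernel_w; ++kw) {
              float *dst_kw = dst_w + kh * in_kh_step + kw * in_kw_step;
              for (int c = 0; c < 4; ++c) {
                dst_kw[c] += src_w[c] * weight_kh[c]; /* vmla.f32 q0, q1, q2 */
              }
              weight_kh += 4;
            }
          }
          dst_w += in_sw_step;
          src_w += block_channel;
        }
        dst += in_sh_step;
        src += out_h_step;
      }
    }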
+ vmlal.s16 q0, d2, d24
+ vst1.32 {q0}, [r7]
+ add r7, r7, r10
+ subs r12, r12, #1
+ bne LoopKw
+ add r6, r6, r11
+ subs r5, r5, #1
+ bne LoopKh
+ ldr r12, [sp, #72]
+ add r0, r0, r12
+ ldr r8, [sp, #64]
+ add r1, r1, r8
+ subs r4, r4, #1
+ bne LoopW
+ ldr r8, [sp, #68]
+ ldr r12, [sp]
+ add r12, r12, r8
+ str r12, [sp]
+ ldr r8, [sp, #60]
+ ldr r12, [sp, #4]
+ add r12, r12, r8
+ str r12, [sp, #4]
+ subs r3, r3, #1
+ bne LoopH
+
+ pop {r0-r8, r10, r11, pc}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Post.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Post.S
new file mode 100644
index 00000000..afc0cd30
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/DeconvDwInt8Post.S
@@ -0,0 +1,84 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums,
+// int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min,
+// int32_t acc_max)
+// r0: dst, r1: output_buffer, r2: bias, r3: block_channel, r4: pixel_nums, r5: out_multiplier,
+// r6: left_shift, r7: right_shift, r8: out_zp, r9: acc_min, r10: acc_max
+
+asm_function DeconvDwInt8Post
+ // at return, clang generates "push {lr}" / "pop {pc}" while gcc generates "bx lr"
+ // according to https://stackoverflow.com/questions/53625807
+ // even if we jump to the link register instead of saving it, we still have to save it in subroutine calls anyway
+ // clang's rule seems simpler, though there are no subroutine calls here
+ // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+ push {r4-r8}
+ add sp, sp, #20
+
+ vld1.32 {q9}, [r2] // bias
+ ldr r4, [sp] // pixel_nums
+ ldr r5, [sp, #4]
+ vdup.32 q14, r5 // out_multiplier
+ ldr r6, [sp, #8]
+ vdup.32 q13, r6 // left_shift
+ ldr r5, [sp, #12]
+ vdup.32 q12, r5 // right_shift
+ ldr r6, [sp, #16]
+ vdup.32 q15, r6 // output_zp
+ ldr r7, [sp, #20]
+ vdup.32 q11, r7 // acc_min
+ ldr r8, [sp, #24]
+ vdup.32 q10, r8 // acc_max
+
+ LoopCount:
+ mov r8, r0
+ vld1.32 {q0}, [r1]!
+ vadd.i32 q0, q0, q9 // add bias
+
+ vshl.s32 q0, q0, q13
+ vqrdmulh.s32 q0, q0, q14
+ vand q4, q0, q12
+ vshr.s32 q4, q4, #31
+ vqadd.s32 q0, q0, q4
+ vrshl.s32 q0, q0, q12
+ vadd.i32 q0, q0, q15
+ vmax.s32 q0, q0, q11
+ vmin.s32 q0, q0, q10
+
+ vqmovn.s32 d0, q0
+ vqmovn.s16 d0, q0
+
+ vst1.8 {d0[0]}, [r8]!
+ vst1.8 {d0[1]}, [r8]!
+ vst1.8 {d0[2]}, [r8]!
+ vst1.8 {d0[3]}, [r8]!
+ add r0, r0, r3
+
+ sub r4, r4, #1
+ cmp r4, #1
+ bge LoopCount
+ End:
+ sub sp, sp, #20
+ pop {r4-r8}
+ bx lr
+
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt16to32_8x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt16to32_8x4.S
new file mode 100644
index 00000000..a0464c90
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt16to32_8x4.S
@@ -0,0 +1,249 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset);
+// r0: output, r1: input, r2: weight, r3: ksize, r4: ic8, r5: oc4, r6: offset
+asm_function IndirectGemmInt16to32_8x4
+
+ .macro INIT_ZERO
+ // we could also use "vmov.i32 q12, #0" to initialize q12 with 0
+ veor q12, q12, q12
+ veor q13, q13, q13
+ veor q14, q14, q14
+ veor q15, q15, q15
+ .endm
+
+ // at return, clang generates "push {lr}" / "pop {pc}" while gcc generates "bx lr"
+ // according to https://stackoverflow.com/questions/53625807
+ // even if we jump to the link register instead of saving it, we still have to save it in subroutine calls anyway
+ // clang's rule seems simpler, though there are no subroutine calls here
+ // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+ push {r4-r8, r10, lr}
+
+ ldr r4, [sp, #28]
+ ldr r5, [sp, #32]
+ ldr r6, [sp, #36]
+
+ vpush {q4-q7}
+
+ LoopOc:
+
+ mov r7, r3
+ mov r8, r1
+
+ LoopKsize:
+ mov r10, r0
+ INIT_ZERO
+
+ // load input
+ vld1.16 {q0, q1}, [r8]!
+ // load weight
+ vld1.16 {q4}, [r2]!
+ vmull.s16 q8, d8, d0[0]
+ vmull.s16 q9, d8, d2[0]
+ // load weight
+ vld1.16 {q5}, [r2]!
+ vmlal.s16 q8, d9, d0[1]
+ vmlal.s16 q9, d9, d2[1]
+ // load input
+ vld1.16 {q2, q3}, [r8]!
+ vmlal.s16 q8, d10, d0[2]
+ vmlal.s16 q9, d10, d2[2]
+ vmlal.s16 q8, d11, d0[3]
+ vmlal.s16 q9, d11, d2[3]
+ // load weight
+ vld1.16 {q6, q7}, [r2]!
+ vmull.s16 q10, d8, d4[0]
+ vmull.s16 q11, d8, d6[0]
+
+ subs r12, r4, #1
+ beq LoopIcEnd
+
+ LoopIc:
+
+ vmlal.s16 q10, d9, d4[1]
+ vmlal.s16 q11, d9, d6[1]
+ vmlal.s16 q10, d10, d4[2]
+ vmlal.s16 q11, d10, d6[2]
+ vmlal.s16 q10, d11, d4[3]
+ vmlal.s16 q11, d11, d6[3]
+
+ vmlal.s16 q8, d12, d1[0]
+ vmlal.s16 q9, d12, d3[0]
+ vmlal.s16 q8, d13, d1[1]
+ vmlal.s16 q9, d13, d3[1]
+ vmlal.s16 q8, d14, d1[2]
+ vmlal.s16 q9, d14, d3[2]
+ vmlal.s16 q8, d15, d1[3]
+ vmlal.s16 q9, d15, d3[3]
+ // load input
+ vld1.16 {q0, q1}, [r8]!
+ vmlal.s16 q10, d12, d5[0]
+ vmlal.s16 q11, d12, d7[0]
+ vmlal.s16 q10, d13, d5[1]
+ vmlal.s16 q11, d13, d7[1]
+ vmlal.s16 q10, d14, d5[2]
+ vmlal.s16 q11, d14, d7[2]
+ vmlal.s16 q10, d15, d5[3]
+ vmlal.s16 q11, d15, d7[3]
+
+ // load input
+ vld1.16 {q2, q3}, [r8]!
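IndirectGemmInt16to32_8x4, opened above, is the int16-to-int32 GEMM core that sits between the indirection/packing stage and requantization: one pass produces an 8-pixel by 4-channel tile of int32 accumulators. The packing in the sketch below is my reading of the load pattern (input as [ksize][ic8][8 pixels][8 channels], weight as [oc4][ksize][ic8][8 channels][4 output channels]) and should be treated as a sketch only; the register schedule above interleaves loads and MACs but computes the same sums.

    #include <stdint.h>

    /* Scalar model of IndirectGemmInt16to32_8x4; offset is the byte stride between
     * the 4-int32 output rows of one tile (the vst1.32 ... , r6 stores above). */
    void IndirectGemmInt16to32_8x4Ref(int32_t *output, const int16_t *input,
                                      const int16_t *weight, size_t ksize, size_t ic8,
                                      size_t oc4, size_t offset) {
      for (size_t o = 0; o < oc4; ++o) {
        for (size_t k = 0; k < ksize; ++k) {
          for (size_t r = 0; r < 8; ++r) {      /* 8 output pixels per tile */
            int32_t acc[4] = {0, 0, 0, 0};
            for (size_t ib = 0; ib < ic8; ++ib) {      /* blocks of 8 input channels */
              for (size_t ii = 0; ii < 8; ++ii) {
                int32_t in = input[((k * ic8 + ib) * 8 + r) * 8 + ii];
                const int16_t *w = weight + (((o * ksize + k) * ic8 + ib) * 8 + ii) * 4;
                for (size_t c = 0; c < 4; ++c) acc[c] += in * w[c];
              }
            }
            int32_t *out = (int32_t *)((int8_t *)output +
                                       (o * ksize + k) * 4 * sizeof(int32_t) + r * offset);
            for (size_t c = 0; c < 4; ++c) out[c] = acc[c];
          }
        }
      }
    }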
+ vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + // load weight + vld1.16 {q4, q5}, [r2]! + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d8, d0[0] + vmlal.s16 q9, d8, d2[0] + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load weight + vld1.16 {q6, q7}, [r2]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + vmlal.s16 q10, d8, d4[0] + vmlal.s16 q11, d8, d6[0] + + subs r12, r12, #1 + bne LoopIc + + LoopIcEnd: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vst1.32 {q8}, [r10], r6 + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vst1.32 {q9}, [r10], r6 + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.s16 {q2, q3}, [r8]! 
+ vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vst1.32 {q10}, [r10], r6 + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vst1.32 {q11}, [r10], r6 + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + vst1.32 {q12}, [r10], r6 + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vst1.32 {q13}, [r10], r6 + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + + vst1.32 {q14}, [r10], r6 + vst1.32 {q15}, [r10] + + subs r7, r7, #1 + add r0, r0, #16 + bne LoopKsize + + subs r5, r5, #1 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, r10, pc} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt8_2x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt8_2x4.S new file mode 100644 index 00000000..dd459e1f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/IndirectGemmInt8_2x4.S @@ -0,0 +1,306 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void IndirectGemmInt8_2x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
+// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
+// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
+// r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset
+// r8: input_sum; act_min, act_max, out_zp, out_multiplier, shift_before and shift_after are loaded from the stack into r10/r11/lr as needed
+asm_function IndirectGemmInt8_2x4
+
+ .macro INIT_BIAS
+ veor q10, q10, q10
+ veor q11, q11, q11
+ veor q12, q12, q12
+ veor q13, q13, q13
+ veor q14, q14, q14
+ veor q15, q15, q15
+ .endm
+
+ // at return, clang generates "push {lr}" / "pop {pc}" while gcc generates "bx lr"
+ // according to https://stackoverflow.com/questions/53625807
+ // even if we jump to the link register instead of saving it, we still have to save it in subroutine calls anyway
+ // clang's rule seems simpler, though there are no subroutine calls here
+ // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+ push {r4-r8, r10, r11, lr}
+ vpush {q4-q7}
+ add sp, sp, #96
+
+ ldr r4, [sp]
+ ldr r5, [sp, #4]
+ ldr r6, [sp, #8]
+ ldr r7, [sp, #12]
+
+ mul r5, r4, r5
+ mov r4, #1
+
+ LoopOc:
+
+ mov r8, r4
+ mov r12, r1
+
+ LoopKsize:
+ INIT_BIAS
+ mov r11, r0
+
+ // as some processors do not support the sdot instruction, we use the raw instruction word instead of the mnemonic
+ // dot-product support is still judged dynamically; the instruction word is only used to ensure compilation
+ // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf
+ // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is
+ // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd
+ // mmmmm/nnnnn/ddddd is the number of the neon register, HL is the high/low bit of the index
+
+ // load input for output 1-2
+ vld1.8 {q0, q1}, [r12]!
+ // load weight for oc 1-2
+ vld1.8 {q2, q3}, [r2]!
+ vmull.s8 q6, d0, d4
+ vmull.s8 q7, d0, d6
+ vmlal.s8 q6, d1, d5
+ vmlal.s8 q7, d1, d7
+ vpaddl.s16 q8, q6
+ vpaddl.s16 q9, q7
+ // load weight for oc 3-4
+ vld1.8 {q4, q5}, [r2]!
+ vmull.s8 q6, d0, d8
+ vmull.s8 q7, d0, d10
+ vmlal.s8 q6, d1, d9
+ vmlal.s8 q7, d1, d11
+
+ subs r10, r5, #1
+ beq LoopIcEnd
+
+ LoopIc:
+ // load input for output 1
+ vld1.8 {q0}, [r12]!
+ vpadal.s16 q10, q6
+ vpadal.s16 q11, q7
+ vmull.s8 q6, d2, d4
+ vmull.s8 q7, d2, d6
+ vmlal.s8 q6, d3, d5
+ vmlal.s8 q7, d3, d7
+ vld1.8 {q2, q3}, [r2]!
+ vpadal.s16 q12, q6
+ vpadal.s16 q13, q7
+ vmull.s8 q6, d2, d8
+ vmull.s8 q7, d2, d10
+ vmlal.s8 q6, d3, d9
+ vmlal.s8 q7, d3, d11
+ vld1.8 {q4, q5}, [r2]!
+ vpadal.s16 q14, q6
+ vpadal.s16 q15, q7
+ vmull.s8 q6, d0, d4
+ vmull.s8 q7, d0, d6
+ vmlal.s8 q6, d1, d5
+ vmlal.s8 q7, d1, d7
+ vld1.8 {q1}, [r12]!
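The 2x4 int8 kernel accumulates through the vmull.s8 / vmlal.s8 / vpadal.s16 chain: 8-bit products are formed as 16-bit, summed pairwise into the 32-bit accumulators, and later collapsed by the vpadd reduction. A flattened scalar model follows; the names are mine, and the depth count is the ksize * ic4 product formed by the mul r5, r4, r5 above, with each step consuming one 16-byte group per output pixel. The exact interleaving of the packed a/b buffers is not reproduced here.

    #include <stdint.h>

    /* Scalar model of the 2-row x 4-column int8 accumulation (depth16 = number of
     * 16-byte depth groups in the packed layout). */
    static void GemmInt8Tile2x4Ref(const int8_t *a, const int8_t *b, int32_t acc[2][4],
                                   size_t depth16) {
      size_t depth = depth16 * 16;
      for (size_t r = 0; r < 2; ++r) {
        for (size_t c = 0; c < 4; ++c) {
          int32_t sum = 0;
          for (size_t i = 0; i < depth; ++i) {
            sum += (int32_t)a[r * depth + i] * (int32_t)b[c * depth + i];
          }
          acc[r][c] = sum;
        }
      }
    }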
+ vpadal.s16 q8, q6 + vpadal.s16 q9, q7 + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs r10, r10, #1 + bne LoopIc + + LoopIcEnd: + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + + // pairwise add + vpadd.i32 d16, d16, d17 + vpadd.i32 d18, d18, d19 + vpadd.i32 d20, d20, d21 + vpadd.i32 d22, d22, d23 + vpadd.i32 d24, d24, d25 + vpadd.i32 d26, d26, d27 + vpadd.i32 d28, d28, d29 + vpadd.i32 d30, d30, d31 + + vpadd.i32 d16, d16, d18 + vpadd.i32 d17, d20, d22 + vpadd.i32 d24, d24, d26 + vpadd.i32 d25, d28, d30 + + // load sum + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSum + ldr r10, [sp, #16] + ldr lr, [sp, #48] + cmp lr, #0 + beq SymSum + ldr lr, [sp, #52] + vld1.32 {d0, d1}, [r10] + add r10, r10, lr + vld1.32 {d2, d3}, [r10] + b AddSum + SymSum: + vld1.32 {d0[], d1[]}, [r10]! + vld1.32 {d2[], d3[]}, [r10]! + AddSum: + vsub.i32 q8, q8, q0 + vsub.i32 q12, q12, q1 + NoSum: + cmp r3, #0 + beq NoBias + vld1.32 {d4, d5}, [r3] + vadd.i32 q8, q8, q2 + vadd.i32 q12, q12, q2 + + NoBias: + ldr lr, [sp, #48] + cmp lr, #0 + bne PerChannel + ldr lr, [sp, #36] + vld1.32 {d6[], d7[]}, [lr] + ldr lr, [sp, #32] + vld1.32 {d8[], d9[]}, [lr] + ldr lr, [sp, #40] + vld1.32 {d10[], d11[]}, [lr] + b QuantizeStart + PerChannel: + ldr lr, [sp, #36] + vld1.32 {d6, d7}, [lr] + ldr lr, [sp, #32] + vld1.32 {d8, d9}, [lr] + ldr lr, [sp, #40] + vld1.32 {d10, d11}, [lr] + QuantizeStart: + vshl.s32 q8, q8, q3 + vshl.s32 q12, q12, q3 + + vqrdmulh.s32 q8, q8, q4 + vqrdmulh.s32 q12, q12, q4 + + vand q3, q5, q8 + vshr.s32 q3, q3, #31 + vqadd.s32 q8, q8, q3 + vrshl.s32 q8, q8, q5 + vand q4, q5, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q12, q12, q4 + vrshl.s32 q12, q12, q5 + + ldr r10, [sp, #28] + vdup.32 q6, r10 + vadd.i32 q8, q8, q6 + vadd.i32 q12, q12, q6 + + ldr r10, [sp, #20] + vdup.32 q0, r10 + vmax.s32 q8, q8, q0 + vmax.s32 q12, q12, q0 + + ldr r10, [sp, #24] + vdup.32 q1, r10 + vmin.s32 q8, q8, q1 + vmin.s32 q12, q12, q1 + + vqmovn.s32 d30, q8 + vqmovn.s32 d31, q12 + vqmovn.s16 d0, q15 + + // prefetching is not preferred while writing results in spite of cache missing + // you could try prfm pstl2strm + WriteStart: + cmp r6, #1 + beq Write1 + cmp r6, #2 + beq Write2 + cmp r6, #3 + beq Write3 + b Write4 + Write1: + vst1.8 {d0[0]}, [r11], r7 + vst1.8 {d0[1]}, [r11] + add r0, r0, #1 + b WriteEnd + Write2: + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + add r0, r0, #2 + b WriteEnd + Write3: + add r14, r11, #2 + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + vst1.8 {d0[0]}, [r14], r7 + vst1.8 {d0[1]}, [r14] + add r0, r0, #3 + b WriteEnd + Write4: + vst1.32 {d0[0]}, [r11], r7 + vst1.32 {d0[1]}, [r11] + add r0, r0, #4 + + WriteEnd: + + subs r8, r8, #1 + bne LoopKsize + + cmp r6, #4 + ble LoopOcEnd + ldr lr, [sp, #48] + cmp lr, #0 + beq NoChannelForward + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSumForward + ldr lr, [sp, #16] + add lr, lr, #16 + str lr, [sp, #16] + NoSumForward: + ldr lr, [sp, #36] + add lr, lr, #16 + str lr, [sp, #36] + ldr lr, [sp, #32] + add lr, lr, #16 + str lr, [sp, #32] + ldr lr, [sp, #40] + add lr, lr, #16 + str lr, [sp, #40] + NoChannelForward: + sub r6, r6, #4 + cmp r3, #0 + beq NoStepFowrard + add r3, r3, #16 + NoStepFowrard: + b LoopOc + +LoopOcEnd: + sub sp, sp, #96 + vpop {q4-q7} + pop 
{r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatVecMulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatVecMulFp32.S new file mode 100644 index 00000000..f45eb0d7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatVecMulFp32.S @@ -0,0 +1,195 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: col + +asm_function MatVecMulFp32 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r9, r10, r11, lr} + add sp, sp, #52 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + mov r10, #4 + mul r10, r10, r5 // stride = depth * sizeof(float) + mov r11, #4 + mul r11, r11, r10 // stride x 4 + + cmp r6, #4 + blt Col1Loop + +Col4Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q15, q15, q15 + + cmp r8, #4 + blt Col4Depth1 + + Col4Depth4: + vld1.f32 {q8}, [r7]! + add lr, r9, r10 + vld1.f32 {q0}, [r9]! + vld1.f32 {q1}, [lr], r10 + vld1.f32 {q2}, [lr], r10 + vld1.f32 {q3}, [lr] + + vmla.f32 q9, q8, q0 + vmla.f32 q10, q8, q1 + vmla.f32 q11, q8, q2 + vmla.f32 q12, q8, q3 + sub r8, r8, #4 + cmp r8, #4 + bge Col4Depth4 + + vpadd.f32 d26, d18, d20 + vpadd.f32 d27, d19, d21 + vpadd.f32 d28, d22, d24 + vpadd.f32 d29, d23, d25 + vadd.f32 d30, d26, d27 + vadd.f32 d31, d28, d29 + cmp r8, #0 + beq Col4End + + Col4Depth1: + vld1.f32 {d0[0]}, [r7]! + add lr, r9, r10 + vld1.f32 {d2[0]}, [r9]! + vld1.f32 {d2[1]}, [lr], r10 + vld1.f32 {d3[0]}, [lr], r10 + vld1.f32 {d3[1]}, [lr] + + vmla.f32 q15, q1, d0[0] + subs r8, r8, #1 + bne Col4Depth1 + + Col4End: + cmp r3, #0 + beq Col4Activation + vld1.f32 {q13}, [r3]! + vadd.f32 q15, q15, q13 + + Col4Activation: + cmp r4, #3 + beq Col4Relu6 + cmp r4, #1 + beq Col4Relu + b Col4Write + + Col4Relu6: + vmov.i32 q12, #6 + vcvt.f32.s32 q12, q12 + vmin.f32 q15, q15, q12 + + Col4Relu: + veor q13, q13, q13 + vmax.f32 q15, q15, q13 + + Col4Write: + vst1.f32 {q15}, [r2]! + subs r6, r6, #4 + beq End + add r1, r1, r11 + cmp r6, #4 + bge Col4Loop + +Col1Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + veor q10, q10, q10 + veor q13, q13, q13 + veor q15, q15, q15 + + cmp r8, #4 + blt Col1Depth1 + + Col1Depth4: + vld1.f32 {q0}, [r7]! + vld1.f32 {q1}, [r9]! 
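MatVecMulFp32 computes c = act(a * B^T + bias) for a single input vector: the Col4 path dots the vector against four matrix rows at once (the r10 = depth * sizeof(float) stride walk), and the Col1 path handles the tail columns. A scalar reference, assuming b is laid out row-major as [col][depth] as that stride walk suggests:

    /* Scalar reference for MatVecMulFp32. act_type 1 = Relu, 3 = Relu6
     * (the Relu6 path above falls through into Relu, clamping both sides). */
    void MatVecMulFp32Ref(const float *a, const float *b, float *c, const float *bias,
                          int act_type, int depth, int col) {
      for (int oc = 0; oc < col; ++oc) {
        float sum = bias ? bias[oc] : 0.0f;
        for (int d = 0; d < depth; ++d) {
          sum += a[d] * b[oc * depth + d];
        }
        if (act_type == 3 && sum > 6.0f) sum = 6.0f;
        if ((act_type == 1 || act_type == 3) && sum < 0.0f) sum = 0.0f;
        c[oc] = sum;
      }
    }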
+ + vmla.f32 q10, q1, q0 + sub r8, r8, #4 + cmp r8, #4 + bge Col1Depth4 + + vpadd.f32 d24, d20, d22 + vpadd.f32 d25, d21, d23 + vadd.f32 d30, d24, d25 + cmp r8, #0 + beq Col1End + + Col1Depth1: + vld1.f32 {d0[0]}, [r7]! + vld1.f32 {d2[0]}, [r9]! + + vmla.f32 d30, d2, d0[0] + subs r8, r8, #1 + bne Col1Depth1 + + Col1End: + cmp r3, #0 + beq Col1Activation + vld1.f32 {d28[0]}, [r3]! + vadd.f32 d30, d30, d28 + + Col1Activation: + cmp r4, #3 + beq Col1Relu6 + cmp r4, #1 + beq Col1Relu + b Col1Write + + Col1Relu6: + vmov.i32 d26, #6 + vcvt.f32.s32 d26, d26 + vmin.f32 d30, d30, d26 + + Col1Relu: + veor d24, d24, d24 + vmax.f32 d30, d30, d24 + + Col1Write: + vst1.f32 {d30[0]}, [r2]! + subs r6, r6, #1 + beq End + add r1, r1, r10 + b Col1Loop + +End: + sub sp, sp, #52 + pop {r0-r8, r9, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32.S new file mode 100644 index 00000000..f36fe067 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32.S @@ -0,0 +1,381 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: writeNhwc/writeWino + +asm_function MatmulFloatNeon32 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + add sp, sp, #48 + + ldr r5, [sp, #4] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #32 // sizeof(float) * 8 + mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 8 * depth + ldr lr, [sp, #24] + cmp lr, #0 + beq NoWinoSteps + mov lr, #4 + mul r11, r7, r8 // stride * col * sizeof(float) + mul r11, r11, lr + mov lr, #32 + mul r10, r8, lr // stride * 8 * sizeof(float) +NoWinoSteps: + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopCol: + ldr r6, [sp, #8] // reload lhs row + ldr r0, [sp, #-48] // reload lhs ptr + ldr r2, [sp, #-40] // reload dst ptr + + LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r5, [sp, #4] // reload depth + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + + LoopDepth: + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q2, d0[0] + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + vmla.f32 q12, q1, d1[0] + vmla.f32 q13, q2, d1[0] + vmla.f32 q14, q1, d1[1] + vmla.f32 q15, q2, d1[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! 
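The hot loop of MatmulFloatNeon32 above is a 4x8 register tile: each depth step loads four LHS values (one per output row, q0) and eight RHS values (one per output column, q1/q2), feeding eight vmla.f32 into the accumulators q8-q15. A scalar model of one tile, assuming the packed layouts a:[depth][4] and b:[depth][8] implied by the post-incremented loads:

    /* One 4x8 tile of MatmulFloatNeon32: c[r][k] accumulates a-row r times b-col k. */
    static void MatmulTile4x8Ref(const float *a, const float *b, float c[4][8], int depth) {
      for (int d = 0; d < depth; ++d) {
        for (int r = 0; r < 4; ++r) {
          for (int k = 0; k < 8; ++k) {
            c[r][k] += a[d * 4 + r] * b[d * 8 + k];
          }
        }
      }
    }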
+ vld1.32 {q1}, [r3] + sub r3, r3, #16 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q1 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q1 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q1 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q1 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + + Write: + ldr lr, [sp, #24] + cmp lr, #0 + bne WriteWino + ldr lr, [sp, #20] + cmp lr, #0 + beq WriteC8 + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + cmp r7, #4 + beq Write4 + cmp r7, #5 + beq Write5 + cmp r7, #6 + beq Write6 + cmp r7, #7 + beq Write7 + b Write8 + + Write1: + vst1.32 d16[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + add r2, r2, r8 + b WriteEnd + Write2: + vst1.32 d16, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + add r2, r2, r8 + b WriteEnd + Write3: + add r4, r2, #8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + add r2, r2, r8 + b WriteEnd + Write4: + vst1.32 {d16, d17}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + add r2, r2, r8 + b WriteEnd + Write5: + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30[0], [r4] + add r2, r2, r8 + b WriteEnd + Write6: + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + add r2, r2, r8 + b WriteEnd + Write7: + add lr, r2, #24 + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + vst1.32 d19[0], [lr] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + vst1.32 d23[0], [lr] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 
d26, [r4] + vst1.32 d27[0], [lr] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + vst1.32 d31[0], [lr] + add r2, r2, r8 + b WriteEnd + WriteC8: + vst1.32 {q8, q9}, [r2]! + vst1.32 {q10, q11}, [r2]! + vst1.32 {q12, q13}, [r2]! + vst1.32 {q14, q15}, [r2]! + str r2, [sp, #-40] + b WriteEnd + WriteWino: + vst1.32 {q8, q9}, [r2] + add r2, r2, r11 + vst1.32 {q10, q11}, [r2] + add r2, r2, r11 + vst1.32 {q12, q13}, [r2] + add r2, r2, r11 + vst1.32 {q14, q15}, [r2] + add r2, r2, r11 + b WriteEnd + Write8: + vst1.32 {q8, q9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q10, q11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q12, q13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q14, q15}, [r2] + add r2, r2, r8 + + WriteEnd: + cmp r6, #4 + ble LoopRowEnd + sub r6, r6, #4 // lhs row - 4 + b LoopRow + + LoopRowEnd: + ldr r1, [sp, #-44] + add r1, r1, r12 // rhs ptr + stride + str r1, [sp, #-44] + cmp r3, #0 + beq NoBiasStep + add r3, r3, #32 // bias ptr + stride + NoBiasStep: + ldr lr, [sp, #24] + cmp lr, #0 + bne WinoDstStep + ldr lr, [sp, #20] + cmp lr, #0 + beq NoDstStep + ldr r2, [sp, #-40] + add r2, r2, #32 // dst ptr + stride + str r2, [sp, #-40] + b NoDstStep + WinoDstStep: + ldr r2, [sp, #-40] + add r2, r2, r10 + str r2, [sp, #-40] + NoDstStep: + cmp r7, #8 + ble LoopColEnd + sub r7, r7, #8 // rhs col - 8 + b LoopCol + +LoopColEnd: + sub sp, sp, #48 + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt.S new file mode 100644 index 00000000..83d6113f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt.S @@ -0,0 +1,422 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: writeNhwc/writeWino + +asm_function MatmulFloatNeon32Opt + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + add sp, sp, #48 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #16 // sizeof(float) * 4 + mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 4 * depth + ldr lr, [sp, #20] + cmp lr, #0 + bne NoC8Steps + mov lr, #32 + mul r10, r6, lr +NoC8Steps: + cmp lr, #2 + bne NoWinoSteps + mov lr, #4 + mul r11, r7, r8 // stride * col * sizeof(float) + mul r11, r11, lr + mov lr, #32 + mul r10, r8, lr // stride * 8 * sizeof(float) +NoWinoSteps: + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol: + ldr lr, [sp, #20] + cmp lr, #0 + beq NoReloadDst + ldr r2, [sp, #-40] // reload dst ptr + NoReloadDst: + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmul.f32 q8, q1, d0[0] + vmul.f32 q9, q2, d0[0] + vmul.f32 q10, q1, d0[1] + vmul.f32 q11, q2, d0[1] + vmul.f32 q12, q1, d1[0] + vmul.f32 q13, q2, d1[0] + vmul.f32 q14, q1, d1[1] + vmul.f32 q15, q2, d1[1] + + subs r5, r5, #1 + beq Bias + + LoopDepth: + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q2, d0[0] + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + vmla.f32 q12, q1, d1[0] + vmla.f32 q13, q2, d1[0] + vmla.f32 q14, q1, d1[1] + vmla.f32 q15, q2, d1[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vld1.32 {q1}, [r3]! 
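Relative to MatmulFloatNeon32, the Opt variant trims the inner loop: the accumulators are initialized by the first depth step's vmul.f32 instead of being zeroed with veor and then fed through vmla.f32, and both bias loads post-increment so the bias pointer walks forward without a compensating sub. In scalar form, under the same assumed packing as the previous sketch:

    /* Opt-style tile: fold the zero-init into the first multiply (d == 0 step). */
    static void MatmulTile4x8OptRef(const float *a, const float *b, float c[4][8], int depth) {
      for (int r = 0; r < 4; ++r) {
        for (int k = 0; k < 8; ++k) {
          c[r][k] = a[r] * b[k];                       /* vmul.f32: first depth step */
        }
      }
      for (int d = 1; d < depth; ++d) {
        for (int r = 0; r < 4; ++r) {
          for (int k = 0; k < 8; ++k) {
            c[r][k] += a[d * 4 + r] * b[d * 8 + k];    /* vmla.f32 */
          }
        }
      }
    }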
+ vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q1 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q1 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q1 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q1 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + + Write: + ldr lr, [sp, #20] + cmp lr, #2 + beq WriteWino + cmp lr, #0 + beq WriteC8 + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + cmp r7, #4 + beq Write4 + cmp r7, #5 + beq Write5 + cmp r7, #6 + beq Write6 + cmp r7, #7 + beq Write7 + b Write8 + + Write1: + add lr, r2, #4 + str lr, [sp, #-40] + vst1.32 d16[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + add r2, r2, r8 + add r2, r2, #4 + b WriteEnd + Write2: + add lr, r2, #8 + str lr, [sp, #-40] + vst1.32 d16, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + add r2, r2, r8 + add r2, r2, #8 + b WriteEnd + Write3: + add lr, r2, #12 + str lr, [sp, #-40] + add r4, r2, #8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + add r2, r2, r8 + add r2, r2, #12 + b WriteEnd + Write4: + add lr, r2, #16 + str lr, [sp, #-40] + vst1.32 {d16, d17}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + add r2, r2, r8 + add r2, r2, #16 + b WriteEnd + Write5: + add lr, r2, #20 + str lr, [sp, #-40] + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30[0], [r4] + add r2, r2, r8 + add r2, r2, #20 + b WriteEnd + Write6: + add lr, r2, #24 + str lr, [sp, #-40] + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + add r2, r2, r8 + add r2, r2, #24 + b WriteEnd + Write7: + add lr, r2, #28 + str lr, [sp, #-40] + add lr, r2, #24 + add r4, r2, #16 + vst1.32 {d16, 
d17}, [r2] + vst1.32 d18, [r4] + vst1.32 d19[0], [lr] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + vst1.32 d23[0], [lr] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + vst1.32 d27[0], [lr] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + vst1.32 d31[0], [lr] + add r2, r2, r8 + add r2, r2, #28 + b WriteEnd + WriteC8: + mov lr, r2 + vst1.32 {q8, q9}, [lr]! + vst1.32 {q10, q11}, [lr]! + vst1.32 {q12, q13}, [lr]! + vst1.32 {q14, q15}, [lr]! + add r2, r2, r10 + b WriteEnd + WriteWino: + add lr, r2, r10 + vst1.32 {q8, q9}, [r2] + add r2, r2, r11 + vst1.32 {q10, q11}, [r2] + add r2, r2, r11 + vst1.32 {q12, q13}, [r2] + add r2, r2, r11 + vst1.32 {q14, q15}, [r2] + str lr, [sp, #-40] + b WriteEnd + Write8: + add lr, r2, #32 + str lr, [sp, #-40] + vst1.32 {q8, q9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q10, q11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q12, q13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q14, q15}, [r2] + add r2, r2, r8 + add r2, r2, #32 + + WriteEnd: + cmp r7, #8 + ble LoopColEnd + sub r7, r7, #8 // rhs col - 8 + b LoopCol + + LoopColEnd: + ldr r0, [sp, #-48] + add r0, r0, r12 // rhs ptr + stride + str r0, [sp, #-48] + ldr lr, [sp, #20] + cmp lr, #0 + beq C8DstStep + cmp lr, #2 + beq WinoDstStep + mov lr, #4 + ldr r7, [sp, #12] // reload rhs col + mul lr, lr, r7 + sub r2, r2, lr + str r2, [sp, #-40] + b NoDstStep + C8DstStep: + ldr lr, [sp, #-40] + add r2, lr, #128 + str r2, [sp, #-40] + b NoDstStep + WinoDstStep: + add r2, r2, r10 + str r2, [sp, #-40] + NoDstStep: + cmp r6, #4 + ble LoopRowEnd + sub r6, r6, #4 // lhs row - 4 + b LoopRow + +LoopRowEnd: + sub sp, sp, #48 + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt12x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt12x4.S new file mode 100644 index 00000000..25370099 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulFp32Opt12x4.S @@ -0,0 +1,578 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 + +asm_function MatmulFloatNeon32Opt12x4 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #48 // sizeof(float) * 12 + mul r12, r5, lr // block stride of lhs: sizeof(float) * 12 * depth + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopRowStart: + cmp r6, #4 + ble LoopRow4 + cmp r6, #8 + ble LoopRow8 + +LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmul.f32 q7, q3, d1[1] + + vmul.f32 q8, q3, d2[0] + vmul.f32 q9, q3, d2[1] + vmul.f32 q10, q3, d3[0] + vmul.f32 q11, q3, d3[1] + + vmul.f32 q12, q3, d4[0] + vmul.f32 q13, q3, d4[1] + vmul.f32 q14, q3, d5[0] + vmul.f32 q15, q3, d5[1] + + subs r5, r5, #1 + beq Bias + + LoopDepth: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + vmla.f32 q8, q3, d2[0] + vmla.f32 q9, q3, d2[1] + vmla.f32 q10, q3, d3[0] + vmla.f32 q11, q3, d3[1] + + vmla.f32 q12, q3, d4[0] + vmla.f32 q13, q3, d4[1] + vmla.f32 q14, q3, d5[0] + vmla.f32 q15, q3, d5[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q0 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q0 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q0 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q0 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + b Write + +LoopRow8: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol_R8: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! 
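MatmulFloatNeon32Opt12x4 flips the tile shape to 12 rows by 4 columns: the twelve accumulators occupy q4-q15 (which is why this kernel, unlike the 4x8 ones, must vpush q4-q7), leaving q0-q3 for loads, and the dedicated LoopRow8 / LoopRow4 paths handle the row tail without carrying unused accumulators. A scalar model, assuming a packed as [depth][12] and b as [depth][4]:

    /* One 12x4 tile: twelve LHS rows against four RHS columns per depth step. */
    static void MatmulTile12x4Ref(const float *a, const float *b, float c[12][4], int depth) {
      for (int d = 0; d < depth; ++d) {
        for (int r = 0; r < 12; ++r) {
          for (int k = 0; k < 4; ++k) {
            c[r][k] += a[d * 12 + r] * b[d * 4 + k];
          }
        }
      }
    }

The tall-and-narrow shape trades RHS reuse for LHS reuse; it tends to pay off when col is small relative to row, which is the common shape after im2col in convolution.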
+ vmul.f32 q7, q3, d1[1] + + vmul.f32 q8, q3, d2[0] + vmul.f32 q9, q3, d2[1] + vmul.f32 q10, q3, d3[0] + vmul.f32 q11, q3, d3[1] + + subs r5, r5, #1 + beq Bias_R8 + + LoopDepth_R8: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + vmla.f32 q8, q3, d2[0] + vmla.f32 q9, q3, d2[1] + vmla.f32 q10, q3, d3[0] + vmla.f32 q11, q3, d3[1] + + subs r5, r5, #1 + bne LoopDepth_R8 + + Bias_R8: + cmp r3, #0 + beq Activation_R8 + vld1.32 {q0}, [r3]! + vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q0 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q0 + + Activation_R8: + ldr lr, [sp] + cmp lr, #3 + beq Relu6_R8 + cmp lr, #1 + beq Relu_R8 + b Write + + Relu6_R8: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + + Relu_R8: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + b Write + +LoopRow4: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol_R4: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmul.f32 q7, q3, d1[1] + + subs r5, r5, #1 + beq Bias_R4 + + LoopDepth_R4: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + subs r5, r5, #1 + bne LoopDepth_R4 + + Bias_R4: + cmp r3, #0 + beq Activation_R4 + vld1.32 {q0}, [r3]! 
+ vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + + Activation_R4: + ldr lr, [sp] + cmp lr, #3 + beq Relu6_R4 + cmp lr, #1 + beq Relu_R4 + b Write + + Relu6_R4: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + + Relu_R4: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + + Write: + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + b Write4 + + Write1: + add lr, r2, #4 + str lr, [sp, #-40] + vst1.32 d8[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d10[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d12[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d14[0], [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 d16[0], [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 d18[0], [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 d22[0], [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 d26[0], [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 d30[0], [r2] + add r2, r2, r8 + add r2, r2, #4 + b WriteEnd + Write2: + add lr, r2, #8 + str lr, [sp, #-40] + vst1.32 d8, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d10, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d12, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d14, [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 d16, [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 d18, [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 d22, [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 d26, [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 d30, [r2] + add r2, r2, r8 + add r2, r2, #8 + b WriteEnd + Write3: + add lr, r2, #12 + str lr, [sp, #-40] + add r4, r2, #8 + vst1.32 d8, [r2] + vst1.32 d9[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d10, [r2] + vst1.32 d11[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d12, [r2] + vst1.32 d13[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d14, [r2] + vst1.32 d15[0], [r4] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d18, [r2] + vst1.32 d19[0], [r4] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d22, [r2] + vst1.32 d23[0], [r4] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d26, [r2] + vst1.32 d27[0], [r4] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d30, [r2] + vst1.32 d31[0], [r4] + add r2, r2, r8 + add r2, r2, #12 + b 
WriteEnd + Write4: + add lr, r2, #16 + str lr, [sp, #-40] + vst1.32 {d8, d9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d10, d11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d12, d13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d14, d15}, [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d16, d17}, [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d18, d19}, [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d22, d23}, [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d26, d27}, [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d30, d31}, [r2] + add r2, r2, r8 + add r2, r2, #16 + b WriteEnd + WriteEnd: + cmp r7, #4 + ble LoopColEnd + sub r7, r7, #4 // rhs col - 4 + b LoopCol + + LoopColEnd: + ldr r0, [sp, #-48] + add r0, r0, r12 // lhs ptr + stride + str r0, [sp, #-48] + mov lr, #4 + ldr r7, [sp, #12] // reload rhs col + mul lr, lr, r7 + sub r2, r2, lr + str r2, [sp, #-40] + cmp r6, #12 + ble LoopRowEnd + sub r6, r6, #12 // lhs row - 12 + b LoopRowStart + +LoopRowEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8.S new file mode 100644 index 00000000..6dc036de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8.S @@ -0,0 +1,298 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+//void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
+// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
+// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
+// #-52: a, #-48: b, #-44: dst, #-40: row
+// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
+// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel
+
+asm_function MatmulInt8Neon32
+    push {r0-r11, lr}
+    vpush {q4-q7}
+    add sp, sp, #116
+
+    ldr r4, [sp] // col
+    ldr r7, [sp, #40] // output stride
+    mov r8, #0 // output channels offset
+    ldr r10, [sp, #44]
+    cmp r10, #0
+    beq L1
+    ldr r6, [sp, #8] // load input_sums ptr if per_channel
+L1:
+    cmp r4, #0 // if at the end of col
+    ble End1
+
+    ldr r0, [sp, #-52] // reload a ptr
+    ldr r3, [sp, #-40] // reset row counter
+    ldr r10, [sp, #44]
+    cmp r10, #0
+    bne L2
+    ldr r6, [sp, #8] // reload input_sums ptr if per_tensor
+L2:
+    cmp r3, #0 // if at the end of row
+    ble End2
+
+    ldr r1, [sp, #-48] // reload b ptr
+    ldr r5, [sp, #4] // reset deep16
+    vmov.i32 q6, #0
+    vmov.i32 q7, #0
+    vmov.i32 q8, #0
+    vmov.i32 q9, #0
+    vmov.i32 q10, #0
+    vmov.i32 q11, #0
+    vmov.i32 q12, #0
+    vmov.i32 q13, #0
+L3:
+    cmp r5, #0
+    beq End3
+
+    vld1.8 {d0, d1, d2, d3}, [r0]!
+    vld1.8 {d8, d9, d10, d11}, [r1]!
+    vmull.s8 q14, d0, d8
+    vmull.s8 q2, d0, d10
+    vmull.s8 q15, d2, d8
+    vmull.s8 q3, d2, d10
+    vmlal.s8 q14, d1, d9
+    vmlal.s8 q2, d1, d11
+    vmlal.s8 q15, d3, d9
+    vmlal.s8 q3, d3, d11
+
+    vpadal.s16 q6, q14
+    vpadal.s16 q7, q2
+    vpadal.s16 q8, q15
+    vpadal.s16 q9, q3
+
+    vld1.8 {d0, d1, d2, d3}, [r0]!
+    vmull.s8 q14, d0, d8
+    vmull.s8 q2, d0, d10
+    vmull.s8 q15, d2, d8
+    vmull.s8 q3, d2, d10
+    vmlal.s8 q14, d1, d9
+    vmlal.s8 q2, d1, d11
+    vmlal.s8 q15, d3, d9
+    vmlal.s8 q3, d3, d11
+
+    vpadal.s16 q10, q14
+    vpadal.s16 q11, q2
+    vpadal.s16 q12, q15
+    vpadal.s16 q13, q3
+    sub r5, r5, #16 // deep16 -= 16
+    b L3
+
+End3:
+    vpadd.i32 d0, d12, d13
+    vpadd.i32 d1, d14, d15
+    vpadd.i32 d2, d16, d17
+    vpadd.i32 d3, d18, d19
+    vpadd.i32 d4, d20, d21
+    vpadd.i32 d5, d22, d23
+    vpadd.i32 d6, d24, d25
+    vpadd.i32 d7, d26, d27
+
+    vpadd.i32 d28, d0, d1
+    vpadd.i32 d29, d2, d3
+    vpadd.i32 d30, d4, d5
+    vpadd.i32 d31, d6, d7
+
+    // Add weight_bias
+    ldr r9, [sp, #12] // reload weight_bias ptr
+    add r9, r9, r8
+    vld1.32 {d26}, [r9]!
+    vadd.i32 d28, d28, d26
+    vadd.i32 d29, d29, d26
+    vadd.i32 d30, d30, d26
+    vadd.i32 d31, d31, d26
+
+    ldr r10, [sp, #44]
+    cmp r10, #0
+    bgt PerChannel
+
+PerTensor:
+    // Subtract input_sums
+    vld1.32 {d24, d25}, [r6]!
+    vdup.32 d20, d24[0]
+    vdup.32 d21, d24[1]
+    vdup.32 d22, d25[0]
+    vdup.32 d23, d25[1]
+    vsub.s32 d28, d28, d20
+    vsub.s32 d29, d29, d21
+    vsub.s32 d30, d30, d22
+    vsub.s32 d31, d31, d23
+
+    // Apply left shift
+    ldr r10, [sp, #32]
+    ldr r11, [r10]
+ vdup.32 q9, r11 + vshl.s32 q14, q14, q9 + vshl.s32 q15, q15, q9 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + ldr r11, [r10] + vdup.32 q8, r11 + vqrdmulh.s32 q14, q14, q8 + vqrdmulh.s32 q15, q15, q8 + + // Apply right shift + ldr r10, [sp, #36] + ldr r11, [r10] + vdup.32 q7, r11 + vand q6, q7, q14 + vshr.s32 q6, q6, #31 + vqadd.s32 q14, q14, q6 + vrshl.s32 q14, q14, q7 + vand q5, q7, q15 + vshr.s32 q5, q5, #31 + vqadd.s32 q15, q15, q5 + vrshl.s32 q15, q15, q7 + b AddDstZP + +PerChannel: + // Subtract input_sums + vld1.32 {d24, d25, d26, d27}, [r6]! + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + + // Apply left shift + ldr r10, [sp, #32] + add r10, r10, r8 + vld1.32 {d23}, [r10] + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + add r10, r10, r8 + vld1.32 {d22}, [r10] + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + // Apply right shift + ldr r10, [sp, #36] + add r10, r10, r8 + vld1.32 {d21}, [r10] + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 + +AddDstZP: + // Add the destination zero point + ldr r10, [sp, #24] + vdup.32 q4, r10 + vadd.i32 q14, q14, q4 + vadd.i32 q15, q15, q4 + + // Apply the act_min bound + ldr r10, [sp, #16] + vdup.32 q3, r10 + vmax.s32 q14, q14, q3 + vmax.s32 q15, q15, q3 + + // Apply the act_max bound + ldr r10, [sp, #20] + vdup.32 q2, r10 + vmin.s32 q14, q14, q2 + vmin.s32 q15, q15, q2 + + // Cast-and-saturate from int32 to int16 + vqmovn.s32 d28, q14 + vqmovn.s32 d29, q15 + + // Cast-and-saturate from int16 to int8 + vqmovn.s16 d30, q14 + + // start to write + cmp r4, #2 + bge WriteCol2 + cmp r4, #1 + beq WriteCol1 + b EndWrite + +WriteCol2: + vst1.16 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.16 {d30[1]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.16 {d30[2]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.16 {d30[3]}, [r2], r7 + b EndWrite + +WriteCol1: + vst1.8 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.8 {d30[2]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.8 {d30[4]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.8 {d30[6]}, [r2], r7 + b EndWrite + +EndWrite: + sub r3, r3, #4 // a row counter -= 4 + b L2 + +End2: + sub r4, r4, #2 // b col counter -= 2 + ldr r1, [sp, #-48] // load b ptr + ldr r9, [sp, #4] + mov r10, #2 + mul r9, r9, r10 // the stride of b + add r1, r1, r9 // b ptr + stride + str r1, [sp, #-48] + ldr r2, [sp, #-44] // load dst ptr + add r2, r2, #2 // dst ptr + offset + str r2, [sp, #-44] + add r8, r8, #8 // output channels offset + 2*sizeof(int) + b L1 + +End1: + sub sp, sp, #116 + vpop {q4-q7} + pop {r0-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8Opt.S new file mode 100644 index 00000000..16426bfd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulInt8Opt.S @@ -0,0 +1,300 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed 
under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
+// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
+// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel,
+// int *filter_zp);
+// #-48: a, #-44: b, #-40: dst, #-36: row
+// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
+// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp
+
+asm_function MatmulInt8Opt
+    push {r0-r8, r10, r11, lr}
+    vpush {q4-q7}
+    add sp, sp, #112
+
+    ldr r5, [sp, #4]
+    ldr r6, [sp, #8] // reload input_sums ptr
+    ldr r8, [sp, #40]
+    mov r10, #4
+    mul r10, r10, r5 // lhs step
+    mov r11, #4
+    mul r11, r11, r8 // dst step
+LoopRow:
+    ldr r1, [sp, #-44] // reload rhs ptr
+    ldr r4, [sp] // reload rhs col
+    ldr lr, [sp, #-40]
+    vmov.32 d4[0], lr // reload dst ptr
+    ldr lr, [sp, #32]
+    vmov.32 d4[1], lr // reload left shift
+    ldr lr, [sp, #28]
+    vmov.32 d5[0], lr // reload multiplier
+    ldr lr, [sp, #36]
+    vmov.32 d5[1], lr // reload right shift
+    ldr r7, [sp, #48] // reload filter_zp
+    ldr r12, [sp, #12] // reload bias ptr
+
+    LoopCol:
+        vmov.32 r2, d4[0] // reload dst ptr
+        ldr r0, [sp, #-48] // reload lhs ptr
+        ldr r5, [sp, #4] // reload depth
+
+        vmov.i32 q6, #0
+        vmov.i32 q7, #0
+        vmov.i32 q8, #0
+        vmov.i32 q9, #0
+        vmov.i32 q10, #0
+        vmov.i32 q11, #0
+        vmov.i32 q12, #0
+        vmov.i32 q13, #0
+
+        LoopDepth:
+            vld1.8 {q0-q1}, [r0]!
+            vld1.8 {q4-q5}, [r1]!
+            vmull.s8 q14, d0, d8
+            vmull.s8 q15, d2, d8
+            vmlal.s8 q14, d1, d9
+            vmlal.s8 q15, d3, d9
+            vpadal.s16 q6, q14
+            vpadal.s16 q8, q15
+            vmull.s8 q14, d0, d10
+            vmull.s8 q15, d2, d10
+            vmlal.s8 q14, d1, d11
+            vmlal.s8 q15, d3, d11
+            vld1.8 {q0-q1}, [r0]!
+
+            vpadal.s16 q7, q14
+            vpadal.s16 q9, q15
+
+            vmull.s8 q14, d0, d8
+            vmull.s8 q15, d2, d8
+            vmlal.s8 q14, d1, d9
+            vmlal.s8 q15, d3, d9
+            vpadal.s16 q10, q14
+            vpadal.s16 q12, q15
+            vmull.s8 q14, d0, d10
+            vmull.s8 q15, d2, d10
+            vmlal.s8 q14, d1, d11
+            vmlal.s8 q15, d3, d11
+
+            vpadal.s16 q11, q14
+            vpadal.s16 q13, q15
+
+            cmp r5, #16
+            ble LoopDepthEnd
+            sub r5, r5, #16
+            b LoopDepth
+
+        LoopDepthEnd:
+            vpadd.i32 d12, d12, d13
+            vpadd.i32 d14, d14, d15
+            vpadd.i32 d16, d16, d17
+            vpadd.i32 d18, d18, d19
+            vpadd.i32 d20, d20, d21
+            vpadd.i32 d22, d22, d23
+            vpadd.i32 d24, d24, d25
+            vpadd.i32 d26, d26, d27
+
+            vpadd.i32 d28, d12, d14
+            vpadd.i32 d29, d16, d18
+            vpadd.i32 d30, d20, d22
+            vpadd.i32 d31, d24, d26
+
+        Bias:
+            cmp r12, #0
+            beq NoBias
+            vld1.32 {d26}, [r12]!
+ vadd.i32 d28, d28, d26 + vadd.i32 d29, d29, d26 + vadd.i32 d30, d30, d26 + vadd.i32 d31, d31, d26 + + NoBias: + ldr lr, [sp, #44] + cmp lr, #0 + bne PerChannel + + PerTensor: + vld1.32 {d24, d25}, [r6] + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vsub.s32 d28, d28, d20 + vsub.s32 d29, d29, d21 + vsub.s32 d30, d30, d22 + vsub.s32 d31, d31, d23 + + vmov.32 lr, d4[1] + vld1.32 {d18[], d19[]}, [lr] + vshl.s32 q14, q14, q9 + vshl.s32 q15, q15, q9 + + vmov.32 lr, d5[0] + vld1.32 {d16[], d17[]}, [lr] + vqrdmulh.s32 q14, q14, q8 + vqrdmulh.s32 q15, q15, q8 + + vmov.32 lr, d5[1] + vld1.32 {d14[], d15[]}, [lr] + vand q6, q7, q14 + vshr.s32 q6, q6, #31 + vqadd.s32 q14, q14, q6 + vrshl.s32 q14, q14, q7 + vand q5, q7, q15 + vshr.s32 q5, q5, #31 + vqadd.s32 q15, q15, q5 + vrshl.s32 q15, q15, q7 + b Quantize + + PerChannel: + vld1.32 {d24, d25}, [r6] + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vld1.32 {d19}, [r7]! + vmul.s32 d24, d20, d19 + vmul.s32 d25, d21, d19 + vmul.s32 d26, d22, d19 + vmul.s32 d27, d23, d19 + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + + vmov.32 lr, d4[1] + vld1.32 {d23}, [lr]! + vmov.32 d4[1], lr + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + vmov.32 lr, d5[0] + vld1.32 {d22}, [lr]! + vmov.32 d5[0], lr + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + vmov.32 lr, d5[1] + vld1.32 {d21}, [lr]! + vmov.32 d5[1], lr + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 + + Quantize: + ldr lr, [sp, #24] + vdup.32 q0, lr + vadd.i32 q14, q14, q0 + vadd.i32 q15, q15, q0 + + ldr lr, [sp, #16] + vdup.32 q1, lr + vmax.s32 q14, q14, q1 + vmax.s32 q15, q15, q1 + + ldr lr, [sp, #20] + vdup.32 q0, lr + vmin.s32 q14, q14, q0 + vmin.s32 q15, q15, q0 + + vqmovn.s32 d28, q14 + vqmovn.s32 d29, q15 + + vqmovn.s16 d30, q14 + + cmp r4, #1 + beq Write1 + b Write2 + + Write1: + vmov.32 lr, d4[0] + add lr, lr, #1 + vmov.32 d4[0], lr + vst1.8 {d30[0]}, [r2], r8 + cmp r3, #1 + beq WriteEnd + vst1.8 {d30[2]}, [r2], r8 + cmp r3, #2 + beq WriteEnd + vst1.8 {d30[4]}, [r2], r8 + cmp r3, #3 + beq WriteEnd + vst1.8 {d30[6]}, [r2], r8 + b WriteEnd + + Write2: + vmov.32 lr, d4[0] + add lr, lr, #2 + vmov.32 d4[0], lr + vst1.16 {d30[0]}, [r2], r8 + cmp r3, #1 + beq WriteEnd + vst1.16 {d30[1]}, [r2], r8 + cmp r3, #2 + beq WriteEnd + vst1.16 {d30[2]}, [r2], r8 + cmp r3, #3 + beq WriteEnd + vst1.16 {d30[3]}, [r2], r8 + + WriteEnd: + cmp r4, #2 + ble LoopColEnd + sub r4, r4, #2 + b LoopCol + +LoopColEnd: + cmp r3, #4 + ble LoopRowEnd + ldr lr, [sp, #-48] + add lr, lr, r10 + str lr, [sp, #-48] + ldr lr, [sp, #-40] + add lr, lr, r11 + str lr, [sp, #-40] + sub r3, r3, #4 + add r6, r6, #16 + b LoopRow + +LoopRowEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulWinogradFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulWinogradFp32.S new file mode 100644 index 00000000..49d30d0a --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/MatmulWinogradFp32.S @@ -0,0 +1,186 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void MatrixMultiplyWinograd(float *matrix_a, float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel)
+    // r0: matrix_a, r1: matrix_b, r2: matrix_c, r3: m, r4: k, r5: n, r6: in_channel, r7: c4_channel * 4
+    // #-56: matrix_a, #-52: matrix_b, #-48: matrix_c, #-44: m, #0: k, #4: n, #8: in_channel, #12: c4_channel * 4
+asm_function MatrixMultiplyWinograd
+    // at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr"
+    // according to https://stackoverflow.com/questions/53625807
+    // even if we jump to the link register instead of saving it, we still have to save it in subroutine calls anyway
+    // clang's rule seems simpler, though there are no subroutine calls here
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+    push {r0-r12, lr}
+    vpush {q4-q7}
+    add sp, sp, #120
+
+    mov r0, #4
+    ldr r4, [sp, #4] // n
+    ldr r5, [sp, #8] // in_channel
+    ldr r6, [sp] // k
+    mul r5, r5, r0 // in_channel * 4
+    mul r4, r4, r0 // n * 4
+    mul r6, r6, r5 // in_channel * 4 * k
+
+    // r3 = m
+    // r2 = dst
+    LoopM:
+        ldr r7, [sp, #4] // n
+        ldr r8, [sp, #-52] // matrix_b
+    LoopN:
+        ldr r0, [sp, #4] // n
+        ldr r1, [sp, #-44] // m
+        sub r0, r0, r7 // ni
+        mul r0, r0, r1 // ni * m
+        sub r1, r1, r3 // mi
+        add r0, r0, r1 // ni * m + mi
+        ldr r1, [sp, #12]
+        mul r9, r0, r1 // (ni * m + mi) * c4_channel * 4
+        add r11, r2, r9 // dst + offset
+
+        ldr r10, [sp, #8] // in_channel
+        ldr r9, [sp, #-56] // src
+        cmp r10, #16
+        bge LoopC16
+        cmp r10, #8
+        bge LoopC8
+        cmp r10, #4
+        bge LoopC4
+        cmp r10, #1
+        bge LoopC
+        b EndLoopC
+
+    LoopC16:
+        mov r0, r8 // mat_b1
+        ldr r12, [sp] // k
+        veor q5, q5, q5
+        veor q6, q6, q6
+        veor q7, q7, q7
+        veor q8, q8, q8
+    LoopK16:
+        vld1.32 {q0, q1}, [r9]!
+        vld1.32 {q2, q3}, [r9]!
+        add r9, r9, r5
+        sub r9, r9, #64
+        vld1.32 d8[0], [r0], r4
+        vmla.f32 q5, q0, d8[0]
+        vmla.f32 q6, q1, d8[0]
+        vmla.f32 q7, q2, d8[0]
+        vmla.f32 q8, q3, d8[0]
+        subs r12, r12, #1
+        bne LoopK16
+    Write16:
+        vst1.32 {q5, q6}, [r11]!
+        vst1.32 {q7, q8}, [r11]!
+        subs r10, r10, #16
+        beq EndLoopC
+        sub r9, r9, r6
+        add r9, r9, #64
+        cmp r10, #16
+        bge LoopC16
+        cmp r10, #8
+        bge LoopC8
+        cmp r10, #4
+        bge LoopC4
+        cmp r10, #1
+        bge LoopC
+
+    LoopC8:
+        veor q5, q5, q5
+        veor q6, q6, q6
+        mov r0, r8 // mat_b1
+        ldr r12, [sp] // k
+    LoopK8:
+        vld1.32 {q0, q1}, [r9], r5
+        vld1.32 d8[0], [r0], r4
+        vmla.f32 q5, q0, d8[0]
+        vmla.f32 q6, q1, d8[0]
+        subs r12, r12, #1
+        bne LoopK8
+    Write8:
+        vst1.32 {q5, q6}, [r11]!
+ subs r10, r10, #8 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #32 + cmp r10, #8 + bge LoopC8 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC4: + veor q5, q5, q5 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK4: + vld1.32 {q0}, [r9], r5 + vld1.32 d8[0], [r0], r4 + vmla.f32 q5, q0, d8[0] + subs r12, r12, #1 + bne LoopK4 + Write4: + vst1.32 {q5}, [r11]! + subs r10, r10, #4 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #16 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC: + veor q2, q2, q2 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK: + vld1.32 d0[0], [r9], r5 + vld1.32 d2[0], [r0], r4 + vmla.f32 s8, s0, s4 + subs r12, r12, #1 + bne LoopK + Write: + vst1.32 d4[0], [r11]! + subs r10, r10, #1 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #4 + b LoopC + + EndLoopC: + subs r7, r7, #1 + beq EndLoopN + add r8, r8, #4 + b LoopN + EndLoopN: + subs r3, r3, #1 + beq EndLoopM + ldr r0, [sp, #-56] + add r0, r0, r6 + str r0, [sp, #-56] + b LoopM + EndLoopM: + sub sp, sp, #120 + vpop {q4-q7} + pop {r0-r12, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC4.S new file mode 100644 index 00000000..d22ee866 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC4.S @@ -0,0 +1,248 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradPostFuncBiasReluC4 + push {r4-r8, r10, r11, lr} + add sp, sp, #32 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + mov lr, #4 + add r12, r3, r4 + mul r12, r12, lr + + mov lr, #0 + +Loop_C4: + cmp lr, r3 + beq Loop_C1 + mov r11, #4 + mul r10, lr, r11 + add r11, r0, r10 + add lr, lr, #4 + mov r8, r5 + vld1.32 {q12}, [r2]! + +Loop_4x4: + cmp r8, #4 + blt Loop_1x4 + sub r8, r8, #4 + vld1.32 {q0-q1}, [r1]! + vld1.32 {q2-q3}, [r1]! + + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q12 + vadd.f32 q2, q2, q12 + vadd.f32 q3, q3, q12 + + cmp r7, #3 + beq Relu6_4x4 + cmp r7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 +Relu_4x4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 +Write_4x4: + vst1.32 {q0}, [r11], r12 + vst1.32 {q1}, [r11], r12 + vst1.32 {q2}, [r11], r12 + vst1.32 {q3}, [r11], r12 + b Loop_4x4 + +Loop_1x4: + cmp r7, #3 + beq Relu6_1x4 + cmp r7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu6_1x4 +Relu_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! 
+ vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu_1x4 +Write_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {q0}, [r11], r12 + b Write_1x4 + +HW_Add: + add r1, r1, r6 + b Loop_C4 + +Loop_C1: + cmp r4, #0 + beq End + mov r8, r5 + vld1.32 {q12}, [r2]! + mov r11, #4 + mul r10, lr, r11 + add r0, r0, r10 + + cmp r4, #1 + beq Loop_C1_1 + cmp r4, #2 + beq Loop_C1_2 + cmp r4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp r7, #3 + beq Loop_C1_1_Relu6 + cmp r7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp r7, #3 + beq Loop_C1_2_Relu6 + cmp r7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Write + +Loop_C1_3: + add r11, r0, #8 + cmp r7, #3 + beq Loop_C1_3_Relu6 + cmp r7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Write + +End: + sub sp, sp, #32 + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC8.S new file mode 100644 index 00000000..93b860ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PostFuncBiasReluC8.S @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+//void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
+// size_t plane_size, size_t stride, int relu_type);
+// r0 dst r1 src r2 bias
+// r3 oc8div r4 oc8mod r5 plane_size
+// r6 stride r7 relu_type
+
+// q0 ~ q3, q8 ~ q11 value
+// q12 q13 bias data
+// r10 r11 write loop tmp buf
+// q14 relu6 #6; q15 relu #0
+// lr oc8 loop control
+// r8 hw loop control
+
+asm_function PostFuncBiasReluC8
+    push {r4-r8, r10, r11, lr}
+    add sp, sp, #32
+
+    ldr r4, [sp]
+    ldr r5, [sp, #4]
+    ldr r6, [sp, #8]
+    ldr r7, [sp, #12]
+
+    vmov.i32 q14, #6
+    vcvt.f32.s32 q14, q14
+    veor q15, q15, q15
+    mov lr, #0
+
+Loop_C8:
+    cmp lr, r3
+    beq Loop_C1
+    mov r11, #4
+    mul r10, lr, r11
+    add r11, r0, r10
+    add lr, lr, #8
+    mov r8, r5
+    vld1.32 {q12-q13}, [r2]!
+
+Loop_4x8:
+    cmp r8, #4
+    blt Loop_1x8
+    sub r8, r8, #4
+    vld1.32 {q0-q1}, [r1]!
+    vld1.32 {q2-q3}, [r1]!
+    vld1.32 {q8-q9}, [r1]!
+    vld1.32 {q10-q11}, [r1]!
+
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vadd.f32 q2, q2, q12
+    vadd.f32 q3, q3, q13
+    vadd.f32 q8, q8, q12
+    vadd.f32 q9, q9, q13
+    vadd.f32 q10, q10, q12
+    vadd.f32 q11, q11, q13
+
+    cmp r7, #3
+    beq Relu6_4x8
+    cmp r7, #1
+    beq Relu_4x8
+    b Write_4x8
+Relu6_4x8:
+    vmin.f32 q0, q0, q14
+    vmin.f32 q1, q1, q14
+    vmin.f32 q2, q2, q14
+    vmin.f32 q3, q3, q14
+    vmin.f32 q8, q8, q14
+    vmin.f32 q9, q9, q14
+    vmin.f32 q10, q10, q14
+    vmin.f32 q11, q11, q14
+Relu_4x8:
+    vmax.f32 q0, q0, q15
+    vmax.f32 q1, q1, q15
+    vmax.f32 q2, q2, q15
+    vmax.f32 q3, q3, q15
+    vmax.f32 q8, q8, q15
+    vmax.f32 q9, q9, q15
+    vmax.f32 q10, q10, q15
+    vmax.f32 q11, q11, q15
+Write_4x8:
+    vst1.32 {q0-q1}, [r11], r6
+    vst1.32 {q2-q3}, [r11], r6
+    vst1.32 {q8-q9}, [r11], r6
+    vst1.32 {q10-q11}, [r11], r6
+    b Loop_4x8
+
+Loop_1x8:
+    cmp r7, #3
+    beq Relu6_1x8
+    cmp r7, #1
+    beq Relu_1x8
+    b Write_1x8
+Relu6_1x8:
+    cmp r8, #0
+    beq Loop_C8
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vmin.f32 q0, q0, q14
+    vmin.f32 q1, q1, q14
+    vmax.f32 q0, q0, q15
+    vmax.f32 q1, q1, q15
+    vst1.32 {q0-q1}, [r11], r6
+    b Relu6_1x8
+Relu_1x8:
+    cmp r8, #0
+    beq Loop_C8
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vmax.f32 q0, q0, q15
+    vmax.f32 q1, q1, q15
+    vst1.32 {q0-q1}, [r11], r6
+    b Relu_1x8
+Write_1x8:
+    cmp r8, #0
+    beq Loop_C8
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vst1.32 {q0-q1}, [r11], r6
+    b Write_1x8
+
+Loop_C1:
+    cmp r4, #0
+    beq End
+    mov r8, r5
+    vld1.32 {q12-q13}, [r2]!
+    mov r11, #4
+    mul r10, lr, r11
+    add r0, r0, r10
+
+    cmp r4, #1
+    beq Loop_C1_1
+    cmp r4, #2
+    beq Loop_C1_2
+    cmp r4, #3
+    beq Loop_C1_3
+    cmp r4, #4
+    beq Loop_C1_4
+    cmp r4, #5
+    beq Loop_C1_5
+    cmp r4, #6
+    beq Loop_C1_6
+    cmp r4, #7
+    beq Loop_C1_7
+
+Loop_C1_1:
+    cmp r7, #3
+    beq Loop_C1_1_Relu6
+    cmp r7, #1
+    beq Loop_C1_1_Relu
+    b Loop_C1_1_Write
+Loop_C1_1_Relu6:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmin.f32 q0, q0, q14
+    vmax.f32 q0, q0, q15
+    vst1.32 {d0[0]}, [r0], r6
+    b Loop_C1_1_Relu6
+Loop_C1_1_Relu:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmax.f32 q0, q0, q15
+    vst1.32 {d0[0]}, [r0], r6
+    b Loop_C1_1_Relu
+Loop_C1_1_Write:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vst1.32 {d0[0]}, [r0], r6
+    b Loop_C1_1_Write
+
+Loop_C1_2:
+    cmp r7, #3
+    beq Loop_C1_2_Relu6
+    cmp r7, #1
+    beq Loop_C1_2_Relu
+    b Loop_C1_2_Write
+Loop_C1_2_Relu6:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmin.f32 q0, q0, q14
+    vmax.f32 q0, q0, q15
+    vst1.32 {d0}, [r0], r6
+    b Loop_C1_2_Relu6
+Loop_C1_2_Relu:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmax.f32 q0, q0, q15
+    vst1.32 {d0}, [r0], r6
+    b Loop_C1_2_Relu
+Loop_C1_2_Write:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vst1.32 {d0}, [r0], r6
+    b Loop_C1_2_Write
+
+Loop_C1_3:
+    add r11, r0, #8
+    cmp r7, #3
+    beq Loop_C1_3_Relu6
+    cmp r7, #1
+    beq Loop_C1_3_Relu
+    b Loop_C1_3_Write
+Loop_C1_3_Relu6:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmin.f32 q0, q0, q14
+    vmax.f32 q0, q0, q15
+    vst1.32 {d0}, [r0], r6
+    vst1.32 {d1[0]}, [r11], r6
+    b Loop_C1_3_Relu6
+Loop_C1_3_Relu:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmax.f32 q0, q0, q15
+    vst1.32 {d0}, [r0], r6
+    vst1.32 {d1[0]}, [r11], r6
+    b Loop_C1_3_Relu
+Loop_C1_3_Write:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vst1.32 {d0}, [r0], r6
+    vst1.32 {d1[0]}, [r11], r6
+    b Loop_C1_3_Write
+
+Loop_C1_4:
+    cmp r7, #3
+    beq Loop_C1_4_Relu6
+    cmp r7, #1
+    beq Loop_C1_4_Relu
+    b Loop_C1_4_Write
+Loop_C1_4_Relu6:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmin.f32 q0, q0, q14
+    vmax.f32 q0, q0, q15
+    vst1.32 {q0}, [r0], r6
+    b Loop_C1_4_Relu6
+Loop_C1_4_Relu:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vmax.f32 q0, q0, q15
+    vst1.32 {q0}, [r0], r6
+    b Loop_C1_4_Relu
+Loop_C1_4_Write:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vst1.32 {q0}, [r0], r6
+    b Loop_C1_4_Write
+
+Loop_C1_5:
+    add r11, r0, #16
+    cmp r7, #3
+    beq Loop_C1_5_Relu6
+    cmp r7, #1
+    beq Loop_C1_5_Relu
+    b Loop_C1_5_Write
+Loop_C1_5_Relu6:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vmin.f32 q0, q0, q14
+    vmin.f32 q1, q1, q14
+    vmax.f32 q0, q0, q15
+    vmax.f32 q1, q1, q15
+    vst1.32 {q0}, [r0], r6
+    vst1.32 {d2[0]}, [r11], r6
+    b Loop_C1_5_Relu6
+Loop_C1_5_Relu:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vmax.f32 q0, q0, q15
+    vmax.f32 q1, q1, q15
+    vst1.32 {q0}, [r0], r6
+    vst1.32 {d2[0]}, [r11], r6
+    b Loop_C1_5_Relu
+Loop_C1_5_Write:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vst1.32 {q0}, [r0], r6
+    vst1.32 {d2[0]}, [r11], r6
+    b Loop_C1_5_Write
+
+Loop_C1_6:
+    add r11, r0, #16
+    cmp r7, #3
+    beq Loop_C1_6_Relu6
+    cmp r7, #1
+    beq Loop_C1_6_Relu
+    b Loop_C1_6_Write
+Loop_C1_6_Relu6:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+    vadd.f32 q0, q0, q12
+    vadd.f32 q1, q1, q13
+    vmin.f32 q0, q0, q14
+    vmin.f32 q1, q1, q14
+    vmax.f32 q0, q0, q15
+    vmax.f32 q1, q1, q15
+    vst1.32 {q0}, [r0], r6
+    vst1.32 {d2}, [r11], r6
+    b Loop_C1_6_Relu6
+Loop_C1_6_Relu:
+    cmp r8, #0
+    beq End
+    sub r8, r8, #1
+    vld1.32 {q0-q1}, [r1]!
+ vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Write + +Loop_C1_7: + add r11, r0, #16 + add r10, r0, #24 + cmp r7, #3 + beq Loop_C1_7_Relu6 + cmp r7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Write + +End: + sub sp, sp, #32 + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Peroc.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Peroc.S new file mode 100644 index 00000000..0a557ac3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Peroc.S @@ -0,0 +1,143 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div2, +// size_t oc_res2, size_t stride); + +// r0 src +// r1 sum +// r2 zp +// r3 hw4 +// r4 ic16 +// r5 oc_div2 +// r6 oc_res2 +// r7 stride + +asm_function PreSum4x16Int8Peroc + push {r4-r11, lr} + vpush {q4-q7} + add sp, sp, #100 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + mov r8, #0 + mov r10, #8 + +RowLoop: + cmp r8, r3 + beq End + add r8, r8, #4 + vmov.s32 q13, #0 + mov r9, #0 + mov r11, r2 + +Sum: + cmp r9, r4 + beq Mul + add r9, r9, #16 + + vld1.8 {q0, q1}, [r0]! + vld1.8 {q2, q3}, [r0]! 
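+    // The widening-add chain below is the heart of PreSum4x16Int8Peroc: it
+    // reduces the 64 int8 values just loaded (one 4-row x 16-col tile) to four
+    // per-row int32 sums, accumulated lane-wise in q13. A hedged scalar C
+    // sketch of what one iteration contributes ("tile" is a hypothetical name
+    // for the 64 bytes consumed by the two loads above):
+    //   for (int r = 0; r < 4; ++r) {
+    //     for (int i = 0; i < 16; ++i) {
+    //       sum[r] += (int32_t)tile[r * 16 + i];
+    //     }
+    //   }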
+
+    vpaddl.s8 q4, q0
+    vpaddl.s8 q5, q1
+    vpaddl.s8 q6, q2
+    vpaddl.s8 q7, q3
+
+    vpaddl.s16 q0, q4
+    vpaddl.s16 q1, q5
+    vpaddl.s16 q2, q6
+    vpaddl.s16 q3, q7
+
+    vpaddl.s32 q4, q0
+    vpaddl.s32 q5, q1
+    vpaddl.s32 q6, q2
+    vpaddl.s32 q7, q3
+
+    vqmovn.s64 d0, q4
+    vqmovn.s64 d1, q5
+    vqmovn.s64 d2, q6
+    vqmovn.s64 d3, q7
+
+    vpaddl.s32 q4, q0
+    vpaddl.s32 q5, q1
+
+    vqmovn.s64 d0, q4
+    vqmovn.s64 d1, q5
+
+    vadd.i32 q13, q13, q0
+    b Sum
+
+Mul:
+    mov r12, r1
+    add r1, r1, #32
+    mov r9, #0
+
+    vdup.32 d1, d26[0]
+    vdup.32 d2, d26[1]
+    vdup.32 d3, d27[0]
+    vdup.32 d4, d27[1]
+
+Write:
+
+    cmp r9, r5
+    beq OcRes
+    add r9, r9, #2
+    vld1.32 {d9}, [r11]!
+
+    vmul.i32 d5, d1, d9
+    vmul.i32 d6, d2, d9
+    vmul.i32 d7, d3, d9
+    vmul.i32 d8, d4, d9
+
+    vst1.32 d5, [r12], r10
+    vst1.32 d6, [r12], r10
+    vst1.32 d7, [r12], r10
+    vst1.32 d8, [r12], r10
+    add r12, r12, r7
+    b Write
+
+OcRes:
+    cmp r6, #0
+    beq RowLoop
+
+    vmov.s32 d9, #0
+    vld1.8 {d9[0]}, [r11]
+
+    vmul.i32 d5, d1, d9
+    vmul.i32 d6, d2, d9
+    vmul.i32 d7, d3, d9
+    vmul.i32 d8, d4, d9
+
+    vst1.32 d5, [r12], r10
+    vst1.32 d6, [r12], r10
+    vst1.32 d7, [r12], r10
+    vst1.32 d8, [r12], r10
+    b RowLoop
+
+End:
+    sub sp, sp, #100
+    vpop {q4-q7}
+    pop {r4-r11, pc}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Pert.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Pert.S
new file mode 100644
index 00000000..d0ad50c2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/PreSum4x16Int8Pert.S
@@ -0,0 +1,94 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
+
+// r0 src
+// r1 sum
+// r2 row4
+// r3 col16
+// r4 filter_zp
+
+asm_function PreSum4x16Int8Pert
+    push {r4-r8, r10, r11, lr}
+    vpush {q4-q7}
+    add sp, sp, #96
+
+    ldr r4, [sp]
+
+    vdup.32 q10, r4
+    mov r5, #0
+    mov r7, #16
+
+RowLoop:
+    cmp r5, r2
+    beq End
+    add r5, r5, #4
+    vmov.s32 q13, #0
+    mov r6, #0
+
+CalLoop:
+    cmp r6, r3
+    beq Write
+    add r6, r6, #16
+
+    vld1.8 {q0, q1}, [r0]!
+    vld1.8 {q2, q3}, [r0]!
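+    // Same widening-add reduction as the per-channel variant above; the
+    // per-tensor difference is only in the epilogue, where every row sum is
+    // scaled by the single shared filter_zp (the vmul against q10) before
+    // being stored. A hedged C sketch of one 4-row tile, assuming src is
+    // packed as consecutive 4x16 int8 blocks ("tile_base" is a hypothetical
+    // name for the current tile):
+    //   int32_t s[4] = {0, 0, 0, 0};
+    //   for (int c = 0; c < col16; c += 16) {
+    //     for (int r = 0; r < 4; ++r) {
+    //       for (int i = 0; i < 16; ++i) {
+    //         s[r] += tile_base[c * 4 + r * 16 + i];
+    //       }
+    //     }
+    //   }
+    //   for (int r = 0; r < 4; ++r) sum[r] = s[r] * filter_zp;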
+ + vpaddl.s8 q4, q0 + vpaddl.s8 q5, q1 + vpaddl.s8 q6, q2 + vpaddl.s8 q7, q3 + + vpaddl.s16 q0, q4 + vpaddl.s16 q1, q5 + vpaddl.s16 q2, q6 + vpaddl.s16 q3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + vpaddl.s32 q6, q2 + vpaddl.s32 q7, q3 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + vqmovn.s64 d2, q6 + vqmovn.s64 d3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + + vadd.i32 q13, q13, q0 + b CalLoop + +Write: + vmul.i32 q13, q13, q10 + vst1.32 {d26, d27}, [r1], r7 + beq RowLoop + +End: + sub sp, sp, #96 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/TiledC4MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/TiledC4MatmulFp32.S new file mode 100644 index 00000000..9c725f71 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/TiledC4MatmulFp32.S @@ -0,0 +1,211 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp32 +//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t cal_num, size_t ic4, size_t oc4) +//x0: dst +//x1: src +//x2: weight +//x3: cal_num +//x4: ic4 +//x5: oc4 + +push {r4-r8, lr} +ldr r4, [sp, #24] +ldr r5, [sp, #28] +//step multi by sizeof(float) +mov r8, #4 +mul r3, r8, r3 + +vpush {q4-q7} + +LoopOc: + mov r6, r1 + mov r8, r0 + subs r7, r4, #1 + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + vld1.32 {q4, q5}, [r2]! + vld1.32 {q6, q7}, [r2]! + + vmul.f32 q8, q4, d0[0] + vmul.f32 q9, q4, d2[0] + vmul.f32 q10, q4, d4[0] + vmul.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vmla.f32 q11, q7, d7[1] + + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + + vmul.f32 q12, q4, d0[0] + vmul.f32 q13, q4, d2[0] + vmul.f32 q14, q4, d4[0] + vmul.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + beq LoopIcEnd + + subs r7, r7, #1 + + vld1.32 {q4, q5}, [r2]! + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + beq LoopIcEndHalf + + LoopIc: + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! 
+ vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vld1.32 {q4, q5}, [r2]! + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q15, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + + subs r7, r7, #1 + bne LoopIc + LoopIcEndHalf: + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + LoopIcEnd: + vst1.32 {q8, q9}, [r0]! + vst1.32 {q10, q11}, [r0]! + vst1.32 {q12, q13}, [r0]! + vst1.32 {q14, q15}, [r0]! + mov r1, r6 + + subs r5, r5, #1 + add r0, r8, r3 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, pc} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransLeft.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransLeft.S new file mode 100644 index 00000000..75768889 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransLeft.S @@ -0,0 +1,230 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length +asm_function WinogradTransLeft + push {r4-r11, lr} + ldr r4, [sp, #36] + ldr r5, [sp, #40] + ldr r6, [sp, #44] + + mov r8, #16 // 4 * sizeof(float) + mul r8, r6, r8 + mul r9, r3, r8 + sub r9, r9, r8 + add r7, r9, r8 // step for S + mov r10, #4 + mul r10, r4, r10 // step for B + +LoopH: + push {r0, r3} + LoopW: + push {r0, r1} + vmov.i32 q14, #0 + mov r11, r6 + InitZero: + vst1.32 {q14}, [r2]! + subs r11, r11, #1 + bne InitZero + + sub r2, r2, r8 + mov r12, r5 + + LoopKStart7: + cmp r12, #7 + blt LoopKStart4 + push {r3-r7} + LoopK7: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + vld1.32 {d2[0]}, [r1], r10 + vld1.32 {d2[1]}, [r1], r10 + vld1.32 {d3[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + add r4, r3, r7 + add r5, r4, r7 + add r6, r5, r7 + add r7, r6, r7 + + LoopLength7: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + vld1.32 {q12}, [r5]! + vmla.f32 q8, q12, d2[0] + vld1.32 {q13}, [r6]! + vmla.f32 q9, q13, d2[1] + vld1.32 {q12}, [r7]! + vmla.f32 q8, q12, d3[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength7 + + sub r2, r2, r8 + sub r12, r12, #7 + add r0, r7, r9 + vmov.32 r1, d30[0] + cmp r12, #7 + bge LoopK7 + + pop {r3-r7} + + LoopKStart4: + cmp r12, #4 + blt LoopKStart3 + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK4: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + add r4, r3, r7 + + LoopLength4: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength4 + + sub r2, r2, r8 + sub r12, r12, #4 + add r0, r4, r9 + vmov.32 r1, d30[0] + cmp r12, #4 + bge LoopK4 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart3: + cmp r12, #3 + blt LoopKStart + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK3: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + + LoopLength3: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength3 + + sub r2, r2, r8 + sub r12, r12, #3 + add r0, r3, r9 + vmov.32 r1, d30[0] + cmp r12, #3 + bge LoopK3 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart: + cmp r12, #0 + beq LoopKEnd + + LoopK: + vld1.32 {d30[0]}, [r1], r10 + + vdup.32 q15, d30[0] + mov r11, r6 + LoopLength: + vld1.32 {q0}, [r2] + vld1.32 {q1}, [r0]! + vmla.f32 q0, q1, q15 + + vst1.32 {q0}, [r2]! 
+ subs r11, r11, #1 + bne LoopLength + subs r12, r12, #1 + + sub r2, r2, r8 + add r0, r0, r9 + bne LoopK + + LoopKEnd: + pop {r0, r1} + subs r3, r3, #1 + add r0, r0, r8 + add r2, r2, r8 + bne LoopW + + pop {r0, r3} + add r1, r1, #4 //sizeof(float) + subs r4, r4, #1 + bne LoopH + + pop {r4-r11, pc} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransRight.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransRight.S new file mode 100644 index 00000000..2abb31ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm32/WinogradTransRight.S @@ -0,0 +1,220 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length +asm_function WinogradTransRight + push {r4-r11, lr} + ldr r4, [sp, #36] + ldr r5, [sp, #40] + ldr r6, [sp, #44] + + mov r8, #16 // 4 * sizeof(float) + mul r8, r6, r8 + mul r9, r5, r8 // step for S + mov r10, #4 + mul r10, r4, r10 // step for B + +LoopH: + push {r1, r3} + LoopW: + push {r0, r1} + vmov.i32 q14, #0 + mov r11, r6 + InitZero: + vst1.32 {q14}, [r2]! + subs r11, r11, #1 + bne InitZero + + sub r2, r2, r8 + mov r12, r5 + LoopKStart7: + cmp r12, #7 + blt LoopKStart4 + push {r3-r7} + LoopK7: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + vld1.32 {d2[0]}, [r1], r10 + vld1.32 {d2[1]}, [r1], r10 + vld1.32 {d3[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + add r4, r3, r8 + add r5, r4, r8 + add r6, r5, r8 + add r7, r6, r8 + LoopLength7: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + vld1.32 {q12}, [r5]! + vmla.f32 q8, q12, d2[0] + vld1.32 {q13}, [r6]! + vmla.f32 q9, q13, d2[1] + vld1.32 {q12}, [r7]! + vmla.f32 q8, q12, d3[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength7 + + sub r2, r2, r8 + sub r12, r12, #7 + mov r0, r7 + vmov.32 r1, d30[0] + cmp r12, #7 + bge LoopK7 + + pop {r3-r7} + + LoopKStart4: + cmp r12, #4 + blt LoopKStart3 + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK4: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + add r4, r3, r8 + + LoopLength4: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! 
+ subs r11, r11, #1 + bne LoopLength4 + + sub r2, r2, r8 + sub r12, r12, #4 + mov r0, r4 + vmov.32 r1, d30[0] + cmp r12, #4 + bge LoopK4 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart3: + cmp r12, #3 + blt LoopKStart + vmov.32 d30[1], r3 + LoopK3: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + + LoopLength3: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength3 + + sub r2, r2, r8 + sub r12, r12, #3 + mov r0, r3 + vmov.32 r1, d30[0] + cmp r12, #3 + bge LoopK3 + + vmov.32 r3, d30[1] + + LoopKStart: + cmp r12, #0 + beq LoopKEnd + LoopK: + vld1.32 {d30[0]}, [r1], r10 + vdup.32 q15, d30[0] + mov r11, r6 + LoopLength: + vld1.32 {q0}, [r2] + vld1.32 {q1}, [r0]! + vmla.f32 q0, q1, q15 + + vst1.32 {q0}, [r2]! + subs r11, r11, #1 + bne LoopLength + + subs r12, r12, #1 + sub r2, r2, r8 + bne LoopK + LoopKEnd: + pop {r0, r1} + subs r3, r3, #1 + add r2, r2, r8 + add r1, r1, #4 //sizeof(float) + bne LoopW + + pop {r1, r3} + add r0, r0, r9 + subs r4, r4, #1 + bne LoopH + + pop {r4-r11, pc} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/AdderFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/AdderFp32.S new file mode 100644 index 00000000..863d3cea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/AdderFp32.S @@ -0,0 +1,622 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function AdderFloatNeon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + ldr x8, [sp, #144] + + mov x20, #48 // sizeof(float) * 12 + mul x17, x5, x20 // block stride of lhs/rhs: sizeof(float) * 12 * depth + + mov x20, #4 + mul x8, x8, x20 + +LoopRowStart: + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + blt LoopRow8 + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v25.4s, v3.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v27.4s, v3.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v29.4s, v3.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v31.4s, v3.4s, v30.4s + + subs x19, x19, #1 + beq Bias + + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v24.4s, v3.4s, v24.4s + fadd v25.4s, v25.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v26.4s, v3.4s, v26.4s + fadd v27.4s, v27.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v28.4s, v3.4s, v28.4s + fadd v29.4s, v29.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v30.4s, v3.4s, v30.4s + fadd v31.4s, v31.4s, v30.4s + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + fneg v25.4s, v25.4s + fneg v27.4s, v27.4s + fneg v29.4s, v29.4s + fneg v31.4s, v31.4s + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + fadd v25.4s, v25.4s, v0.4s + fadd v27.4s, v27.4s, v0.4s + fadd v29.4s, v29.4s, v0.4s + fadd v31.4s, v31.4s, v0.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + 
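+    // Note: unlike an ordinary matmul, this adder kernel accumulates L1
+    // distances: the fabd/fadd pairs build sum_d |a[r][d] - b[d][c]|, and the
+    // fneg in the Bias stage negates that distance before bias and activation
+    // are applied. A hedged C sketch of one output element (ignoring the
+    // packed 12-row lhs / 4-col rhs tile layouts actually used here):
+    //   float acc = 0.0f;
+    //   for (int d = 0; d < depth; ++d) {
+    //     acc += fabsf(a[r][d] - b[d][c]);
+    //   }
+    //   dst[r][c] = -acc + (bias != NULL ? bias[c] : 0.0f);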
Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + Relu8: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov 
x19, x5 // reload depth + + LoopDepthStart4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b Write + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + b Write4 + + Write1: + add x2, x2, #4 + str s9, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s11, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s13, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s15, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s17, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s19, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s21, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s23, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s25, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s27, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s29, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s31, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v9.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v21.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.2s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp
x6, #6 + beq WriteEnd + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.2s}, [x11], x8 + st1 {v31.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + + WriteEnd: + subs x13, x13, #4 // rhs col - 4 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol + +LoopColEnd: + add x0, x0, x17 + mov x20, #4 + mul x20, x20, x7 + sub x11, x11, x20 + mov x2, x11 + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/BigMatmulFp32Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/BigMatmulFp32Opt.S new file mode 100644 index 00000000..6097397d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/BigMatmulFp32Opt.S @@ -0,0 +1,2528 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void BigMatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride + +asm_function BigMatmulFloatNeon64Opt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + stp x29, x30, [sp, #208] + + ldr x8, [sp, #224] + mov x20, #1 + mov x22, #32 + mov x23, #48 + mul x26, x5, x23 // stride for lhs + mul x24, x8, x23 // stride for out + lsl x27, x23, #9 // stride by depth for lhs + lsl x28, x22, #9 // stride by depth for rhs + lsl x22, x5, #5 // stride for rhs + lsl x8, x8, #2 + subs x5, x5, #512 + ble DepthTail +Depth512: + mov x25, x0 // restore lhs + mov x13, x2 // out + mov x10, x6 // restore row + RowStart: + mov x12, x1 // rhs + mov x14, x13 // out + mov x15, x3 // restore bias + mov x9, x7 // restore col + cmp x10, #4 + ble LoopRow4 + cmp x10, #8 + ble LoopRow8 + + LoopRow12: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol12x4 + + LoopCol12x8: + cbz x20, Reload12x8 + cbnz x15, InitFromBias12x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b Compute12x8Enter + InitFromBias12x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + ld1 {v24.4s, v25.4s}, [x15] + ld1 {v26.4s, v27.4s}, [x15] + ld1 {v28.4s, v29.4s}, [x15] + ld1 {v30.4s, v31.4s}, [x15] + add x15, x15, #32 + b Compute12x8Enter + Reload12x8: + bl Reload + Compute12x8Enter: + cbz x21, Write + bl Compute12x8Unit + b Write + + LoopCol12x4: + cbz x20, Reload12x4 + cbnz x15, InitFromBias12x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b Compute12x4Enter + InitFromBias12x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + ld1 {v24.4s}, [x15] + ld1 {v26.4s}, [x15] + ld1 {v28.4s}, [x15] + ld1 {v30.4s}, [x15] + b Compute12x4Enter + Reload12x4: + bl Reload + Compute12x4Enter: + cbz x21, Write + bl Compute12x4Unit + b Write + + LoopRow8: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol8x4 + + LoopCol8x8: + cbz x20, Reload8x8 + cbnz x15, InitFromBias8x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup
v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b Compute8x8Enter + InitFromBias8x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + add x15, x15, #32 + b Compute8x8Enter + Reload8x8: + bl Reload + Compute8x8Enter: + cbz x21, Write + bl Compute8x8Unit + b Write + + LoopCol8x4: + cbz x20, Reload8x4 + cbnz x15, InitFromBias8x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b Compute8x4Enter + InitFromBias8x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + b Compute8x4Enter + Reload8x4: + bl Reload + Compute8x4Enter: + cbz x21, Write + bl Compute8x4Unit + b Write + + LoopRow4: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol4x4 + + LoopCol4x8: + cbz x20, Reload4x8 + cbnz x15, InitFromBias4x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b Compute4x8Enter + InitFromBias4x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + add x15, x15, #32 + b Compute4x8Enter + Reload4x8: + bl Reload + Compute4x8Enter: + cbz x21, Write + bl Compute4x8Unit + b Write + + LoopCol4x4: + cbz x20, Reload4x4 + cbnz x15, InitFromBias4x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b Compute4x4Enter + InitFromBias4x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + b Compute4x4Enter + Reload4x4: + bl Reload + Compute4x4Enter: + cbz x21, Write + bl Compute4x4Unit + +Write: + mov x21, x14 + cmp x9, #1 + beq Write1 + cmp x9, #2 + beq Write2 + cmp x9, #3 + beq Write3 + cmp x9, #4 + beq Write4 + cmp x9, #5 + beq Write5 + cmp x9, #6 + beq Write6 + cmp x9, #7 + beq Write7 + b Write8 + + Write1: + str s8, [x21] + cmp x10, #1 + beq LoopCol + add x21, x21, x8 + str s10, [x21] + cmp x10, #2 + beq LoopCol + add x21, x21, x8 + str s12, [x21] + cmp x10, #3 + beq LoopCol + add x21, x21, x8 + str s14, [x21] + cmp x10, #4 + beq LoopCol + add x21, x21, x8 + str s16, [x21] + cmp x10, #5 + beq LoopCol + add x21, x21, x8 + str s18, [x21] + cmp x10, #6 + beq LoopCol + add x21, x21, x8 + str s20, [x21] + cmp x10, #7 + beq LoopCol + add x21, x21, x8 + str s22, [x21] + cmp x10, #8 + beq LoopCol + add x21, x21, x8 + str s24, [x21] + cmp x10, #9 + beq LoopCol + add x21, x21, x8 + str s26, [x21] + cmp x10, #10 + beq LoopCol + add x21, x21, x8 + str s28, [x21] + cmp x10, #11 + beq LoopCol + add x21, x21, x8 + str s30, [x21] + b LoopCol + Write2: + st1 {v8.2s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.2s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.2s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.2s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.2s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.2s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.2s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 
{v22.2s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.2s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.2s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.2s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.2s}, [x21], x8 + add x21, x21, #8 + b LoopCol + Write3: + add x11, x21, #8 + st1 {v8.2s}, [x21], x8 + st1 {v8.s}[2], [x11], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.2s}, [x21], x8 + st1 {v10.s}[2], [x11], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.2s}, [x21], x8 + st1 {v12.s}[2], [x11], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.2s}, [x21], x8 + st1 {v14.s}[2], [x11], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.2s}, [x21], x8 + st1 {v16.s}[2], [x11], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.2s}, [x21], x8 + st1 {v18.s}[2], [x11], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.2s}, [x21], x8 + st1 {v20.s}[2], [x11], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.2s}, [x21], x8 + st1 {v22.s}[2], [x11], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.2s}, [x21], x8 + st1 {v24.s}[2], [x11], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.2s}, [x21], x8 + st1 {v26.s}[2], [x11], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.2s}, [x21], x8 + st1 {v28.s}[2], [x11], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.2s}, [x21], x8 + st1 {v30.s}[2], [x11] + add x21, x21, #12 + b LoopCol + Write4: + st1 {v8.4s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + add x21, x21, #16 + b LoopCol + Write5: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + str s9, [x11] + cmp x10, #1 + beq LoopCol + add x11, x11, x8 + st1 {v10.4s}, [x21], x8 + str s11, [x11] + cmp x10, #2 + beq LoopCol + add x11, x11, x8 + st1 {v12.4s}, [x21], x8 + str s13, [x11] + cmp x10, #3 + beq LoopCol + add x11, x11, x8 + st1 {v14.4s}, [x21], x8 + str s15, [x11] + cmp x10, #4 + beq LoopCol + add x11, x11, x8 + st1 {v16.4s}, [x21], x8 + str s17, [x11] + cmp x10, #5 + beq LoopCol + add x11, x11, x8 + st1 {v18.4s}, [x21], x8 + str s19, [x11] + cmp x10, #6 + beq LoopCol + add x11, x11, x8 + st1 {v20.4s}, [x21], x8 + str s21, [x11] + cmp x10, #7 + beq LoopCol + add x11, x11, x8 + st1 {v22.4s}, [x21], x8 + str s23, [x11] + cmp x10, #8 + beq LoopCol + add x11, x11, x8 + st1 {v24.4s}, [x21], x8 + str s25, [x11] + cmp x10, #9 + beq LoopCol + add x11, x11, x8 + st1 {v26.4s}, [x21], x8 + str s27, [x11] + cmp x10, #10 + beq LoopCol + add x11, x11, x8 + st1 {v28.4s}, [x21], x8 + str s29, [x11] + cmp x10, #11 + beq LoopCol + add x11, x11, x8 + st1 {v30.4s}, [x21], x8 + str s31, [x11] + add x21, x21, #20 + b LoopCol + Write6: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + cmp x10, #6 + beq LoopCol + st1 
{v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + add x21, x21, #24 + b LoopCol + Write7: + add x11, x21, #16 + add x23, x21, #24 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x23], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x23], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x23], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x23], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x23], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x23], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x23], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x23], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x23], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x23], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x23], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + st1 {v31.s}[2], [x23] + add x21, x21, #28 + b LoopCol + + Write8: + st1 {v8.4s, v9.4s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s, v11.4s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s, v13.4s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s, v15.4s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s, v17.4s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s, v19.4s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s, v21.4s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s, v23.4s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s, v25.4s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s, v27.4s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s, v29.4s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s, v31.4s}, [x21], x8 + add x21, x21, #32 + b LoopCol + +LoopCol: + subs x9, x9, #8 + ble LoopColEnd + add x12, x12, x22 // update rhs + add x14, x14, #32 // update out + cmp x10, #4 + ble LoopRow4 + cmp x10, #8 + ble LoopRow8 + b LoopRow12 + +LoopColEnd: + add x25, x25, x26 // update lhs + add x13, x13, x24 // update out + subs x10, x10, #12 // update row + bgt RowStart + mov x20, #0 + add x0, x0, x27 // update lhs by depth + add x1, x1, x28 // update rhs by depth + subs x5, x5, #512 + bgt Depth512 + +/////////////////////////////////////////////////////// + +DepthTail: + add x5, x5, #512 + mov x13, x2 // out + mov x10, x6 + TailRowStart: + mov x12, x1 // rhs + mov x14, x13 // out + mov x15, x3 // restore bias + mov x9, x7 // restore col + cmp x10, #4 + ble LoopTailRow4 + cmp x10, #8 + ble LoopTailRow8 + + LoopTailRow12: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol12x4 + + 
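+// The tail pass below repeats the Depth512 tiles but also applies the activation:
+// depth is consumed in 512-step chunks, the first chunk seeds the accumulators from
+// bias (or zero), later chunks reload the partial sums from c via Reload, and only
+// the final chunk clamps. A hedged C sketch of that flow (the function name and the
+// plain row-major indexing are illustrative; the kernel really walks packed tiles):
+//
+//   #include <math.h>    /* fminf, fmaxf */
+//   void BigMatmulRef(const float *a, const float *b, float *c, const float *bias,
+//                     int act_type, int depth, int row, int col, int stride) {
+//     for (int d0 = 0; d0 < depth; d0 += 512) {               /* Depth512 / DepthTail */
+//       int dn = (depth - d0 < 512) ? (depth - d0) : 512;
+//       int last = (d0 + dn == depth);
+//       for (int r = 0; r < row; ++r) {
+//         for (int j = 0; j < col; ++j) {
+//           float acc = (d0 == 0) ? (bias ? bias[j] : 0.0f)   /* InitFromBias / xzr */
+//                                 : c[r * stride + j];        /* Reload             */
+//           for (int d = d0; d < d0 + dn; ++d) {
+//             acc += a[r * depth + d] * b[d * col + j];       /* Compute*Unit fmla  */
+//           }
+//           if (last) {                                       /* tail-only clamps   */
+//             if (act_type == 3) acc = fminf(acc, 6.0f);      /* Relu6 */
+//             if (act_type == 3 || act_type == 1) acc = fmaxf(acc, 0.0f);  /* Relu */
+//           }
+//           c[r * stride + j] = acc;                          /* Write / WriteTail  */
+//         }
+//       }
+//     }
+//   }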
LoopTailCol12x8: + cbz x20, ReloadTail12x8 + cbnz x15, InitTailFromBias12x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b ComputeTail12x8Enter + InitTailFromBias12x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + ld1 {v24.4s, v25.4s}, [x15] + ld1 {v26.4s, v27.4s}, [x15] + ld1 {v28.4s, v29.4s}, [x15] + ld1 {v30.4s, v31.4s}, [x15] + add x15, x15, #32 + b ComputeTail12x8Enter + ReloadTail12x8: + bl Reload + ComputeTail12x8Enter: + cbz x21, Activation12x8 + bl Compute12x8Unit + Activation12x8: + cmp x4, #3 + beq Relu612x8 + cmp x4, #1 + beq Relu12x8 + b WriteTail + + Relu612x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu12x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b WriteTail + + LoopTailCol12x4: + cbz x20, ReloadTail12x4 + cbnz x15, InitTailFromBias12x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b ComputeTail12x4Enter + InitTailFromBias12x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + ld1 {v24.4s}, [x15] + ld1 {v26.4s}, [x15] + ld1 {v28.4s}, [x15] + ld1 {v30.4s}, [x15] + b ComputeTail12x4Enter + ReloadTail12x4: + bl Reload + ComputeTail12x4Enter: + cbz x21, Activation12x4 + bl Compute12x4Unit + Activation12x4: + cmp x4, #3 + beq Relu612x4 + cmp x4, #1 + beq Relu12x4 + b WriteTail + + Relu612x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, 
v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + Relu12x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + b WriteTail + + LoopTailRow8: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol8x4 + + LoopTailCol8x8: + cbz x20, ReloadTail8x8 + cbnz x15, InitTailFromBias8x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b ComputeTail8x8Enter + InitTailFromBias8x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + add x15, x15, #32 + b ComputeTail8x8Enter + ReloadTail8x8: + bl Reload + ComputeTail8x8Enter: + cbz x21, Activation8x8 + bl Compute8x8Unit + Activation8x8: + cmp x4, #3 + beq Relu68x8 + cmp x4, #1 + beq Relu8x8 + b WriteTail + + Relu68x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b WriteTail + + LoopTailCol8x4: + cbz x20, ReloadTail8x4 + cbnz x15, InitTailFromBias8x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b ComputeTail8x4Enter + InitTailFromBias8x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + b ComputeTail8x4Enter + ReloadTail8x4: + bl Reload + ComputeTail8x4Enter: + cbz x21, Activation8x4 + bl Compute8x4Unit + Activation8x4: + cmp x4, #3 + beq Relu68x4 + cmp x4, #1 + beq Relu8x4 + b WriteTail + + Relu68x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, 
v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + Relu8x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + b WriteTail + + LoopTailRow4: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol4x4 + + LoopTailCol4x8: + cbz x20, ReloadTail4x8 + cbnz x15, InitTailFromBias4x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b ComputeTail4x8Enter + InitTailFromBias4x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + add x15, x15, #32 + b ComputeTail4x8Enter + ReloadTail4x8: + bl Reload + ComputeTail4x8Enter: + cbz x21, Activation4x8 + bl Compute4x8Unit + Activation4x8: + cmp x4, #3 + beq Relu64x8 + cmp x4, #1 + beq Relu4x8 + b WriteTail + + Relu64x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b WriteTail + + LoopTailCol4x4: + cbz x20, ReloadTail4x4 + cbnz x15, InitTailFromBias4x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b ComputeTail4x4Enter + InitTailFromBias4x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + b ComputeTail4x4Enter + ReloadTail4x4: + bl Reload + ComputeTail4x4Enter: + cbz x21, Activation4x4 + bl Compute4x4Unit + Activation4x4: + cmp x4, #3 + beq Relu64x4 + cmp x4, #1 + beq Relu4x4 + b WriteTail + + Relu64x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + Relu4x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + +WriteTail: + mov x21, x14 + cmp x9, #1 + beq WriteTail1 + cmp x9, #2 + beq WriteTail2 + cmp x9, #3 + beq WriteTail3 + cmp x9, #4 + beq WriteTail4 + cmp x9, #5 + beq WriteTail5 + cmp x9, #6 + beq WriteTail6 + cmp x9, #7 + beq WriteTail7 + b WriteTail8 + + WriteTail1: + str s8, [x21] + cmp x10, #1 + beq LoopTailCol + add x21, x21, x8 + str s10, [x21] + cmp x10, #2 + beq LoopTailCol + add x21, x21, x8 + str s12, [x21] + cmp x10, #3 + beq LoopTailCol + add x21, x21, x8 + str s14, [x21] + cmp x10, #4 + beq LoopTailCol + add x21, x21, x8 + str s16, [x21] + cmp x10, #5 + beq LoopTailCol + add x21, x21, x8 + str s18, [x21] + cmp x10, #6 + beq LoopTailCol + add x21, x21, x8 + str s20, [x21] + cmp x10, #7 + beq LoopTailCol + add x21, x21, x8 + str s22, [x21] + cmp x10, #8 + beq LoopTailCol + add x21, x21, x8 + str s24, [x21] + cmp x10, #9 + beq LoopTailCol + add 
x21, x21, x8 + str s26, [x21] + cmp x10, #10 + beq LoopTailCol + add x21, x21, x8 + str s28, [x21] + cmp x10, #11 + beq LoopTailCol + add x21, x21, x8 + str s30, [x21] + b LoopTailCol + WriteTail2: + st1 {v8.2s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.2s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.2s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.2s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.2s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.2s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.2s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.2s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.2s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.2s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.2s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.2s}, [x21], x8 + add x21, x21, #8 + b LoopTailCol + WriteTail3: + add x11, x21, #8 + st1 {v8.2s}, [x21], x8 + st1 {v8.s}[2], [x11], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.2s}, [x21], x8 + st1 {v10.s}[2], [x11], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.2s}, [x21], x8 + st1 {v12.s}[2], [x11], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.2s}, [x21], x8 + st1 {v14.s}[2], [x11], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.2s}, [x21], x8 + st1 {v16.s}[2], [x11], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.2s}, [x21], x8 + st1 {v18.s}[2], [x11], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.2s}, [x21], x8 + st1 {v20.s}[2], [x11], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.2s}, [x21], x8 + st1 {v22.s}[2], [x11], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.2s}, [x21], x8 + st1 {v24.s}[2], [x11], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.2s}, [x21], x8 + st1 {v26.s}[2], [x11], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.2s}, [x21], x8 + st1 {v28.s}[2], [x11], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.2s}, [x21], x8 + st1 {v30.s}[2], [x11] + add x21, x21, #12 + b LoopTailCol + WriteTail4: + st1 {v8.4s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + add x21, x21, #16 + b LoopTailCol + WriteTail5: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + str s9, [x11] + cmp x10, #1 + beq LoopTailCol + add x11, x11, x8 + st1 {v10.4s}, [x21], x8 + str s11, [x11] + cmp x10, #2 + beq LoopTailCol + add x11, x11, x8 + st1 {v12.4s}, [x21], x8 + str s13, [x11] + cmp x10, #3 + beq LoopTailCol + add x11, x11, x8 + st1 {v14.4s}, [x21], x8 + str s15, [x11] + cmp x10, #4 + beq LoopTailCol + add x11, x11, x8 + st1 {v16.4s}, [x21], x8 + str s17, [x11] + cmp x10, #5 + beq LoopTailCol + add x11, x11, x8 + st1 {v18.4s}, [x21], x8 + str s19, [x11] + cmp x10, #6 + beq LoopTailCol + add x11, x11, x8 + st1 {v20.4s}, [x21], x8 + str s21, [x11] + cmp x10, #7 + beq LoopTailCol + add x11, x11, x8 + st1 {v22.4s}, [x21], x8 + str s23, [x11] + cmp x10, #8 + beq LoopTailCol + add x11, x11, x8 + st1 {v24.4s}, [x21], x8 + str s25, [x11] + cmp x10, #9 + beq LoopTailCol + add x11, x11, 
x8 + st1 {v26.4s}, [x21], x8 + str s27, [x11] + cmp x10, #10 + beq LoopTailCol + add x11, x11, x8 + st1 {v28.4s}, [x21], x8 + str s29, [x11] + cmp x10, #11 + beq LoopTailCol + add x11, x11, x8 + st1 {v30.4s}, [x21], x8 + str s31, [x11] + add x21, x21, #20 + b LoopTailCol + WriteTail6: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + add x21, x21, #24 + b LoopTailCol + WriteTail7: + add x11, x21, #16 + add x23, x21, #24 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x23], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x23], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x23], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x23], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x23], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x23], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x23], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x23], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x23], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x23], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x23], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + st1 {v31.s}[2], [x23] + add x21, x21, #28 + b LoopTailCol + + WriteTail8: + st1 {v8.4s, v9.4s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s, v11.4s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s, v13.4s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s, v15.4s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s, v17.4s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s, v19.4s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s, v21.4s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s, v23.4s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s, v25.4s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s, v27.4s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s, v29.4s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s, v31.4s}, [x21], x8 + add x21, x21, #32 + b LoopTailCol + +LoopTailCol: + subs x9, x9, #8 + 
ble LoopTailEnd + add x12, x12, x22 // update rhs + add x14, x14, #32 + cmp x10, #4 + ble LoopTailRow4 + cmp x10, #8 + ble LoopTailRow8 + b LoopTailRow12 + +LoopTailEnd: + add x0, x0, x26 // update lhs + add x13, x13, x24 // update out + subs x10, x10, #12 // update row + bgt TailRowStart + b End + +Reload: + mov x15, x14 + cmp x9, #1 + beq Reload1 + cmp x9, #2 + beq Reload2 + cmp x9, #3 + beq Reload3 + cmp x9, #4 + beq Reload4 + cmp x9, #5 + beq Reload5 + cmp x9, #6 + beq Reload6 + cmp x9, #7 + beq Reload7 + b Reload8 + + Reload1: + ldr s8, [x15] + cmp x10, #1 + beq ReloadEnd + add x15, x15, x8 + ldr s10, [x15] + cmp x10, #2 + beq ReloadEnd + add x15, x15, x8 + ldr s12, [x15] + cmp x10, #3 + beq ReloadEnd + add x15, x15, x8 + ldr s14, [x15] + cmp x10, #4 + beq ReloadEnd + add x15, x15, x8 + ldr s16, [x15] + cmp x10, #5 + beq ReloadEnd + add x15, x15, x8 + ldr s18, [x15] + cmp x10, #6 + beq ReloadEnd + add x15, x15, x8 + ldr s20, [x15] + cmp x10, #7 + beq ReloadEnd + add x15, x15, x8 + ldr s22, [x15] + cmp x10, #8 + beq ReloadEnd + add x15, x15, x8 + ldr s24, [x15] + cmp x10, #9 + beq ReloadEnd + add x15, x15, x8 + ldr s26, [x15] + cmp x10, #10 + beq ReloadEnd + add x15, x15, x8 + ldr s28, [x15] + cmp x10, #11 + beq ReloadEnd + add x15, x15, x8 + ldr s30, [x15] + b ReloadEnd + Reload2: + ld1 {v8.2s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.2s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.2s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.2s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.2s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.2s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.2s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.2s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.2s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.2s}, [x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.2s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.2s}, [x15], x8 + add x15, x15, #8 + b ReloadEnd + Reload3: + add x19, x15, #8 + ld1 {v8.2s}, [x15], x8 + ld1 {v8.s}[2], [x19], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.2s}, [x15], x8 + ld1 {v10.s}[2], [x19], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.2s}, [x15], x8 + ld1 {v12.s}[2], [x19], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.2s}, [x15], x8 + ld1 {v14.s}[2], [x19], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.2s}, [x15], x8 + ld1 {v16.s}[2], [x19], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.2s}, [x15], x8 + ld1 {v18.s}[2], [x19], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.2s}, [x15], x8 + ld1 {v20.s}[2], [x19], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.2s}, [x15], x8 + ld1 {v22.s}[2], [x19], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.2s}, [x15], x8 + ld1 {v24.s}[2], [x19], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.2s}, [x15], x8 + ld1 {v26.s}[2], [x19], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.2s}, [x15], x8 + ld1 {v28.s}[2], [x19], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.2s}, [x15], x8 + ld1 {v30.s}[2], [x19] + add x15, x15, #12 + b ReloadEnd + Reload4: + ld1 {v8.4s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, 
[x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s}, [x15], x8 + add x15, x15, #16 + b ReloadEnd + Reload5: + add x19, x15, #16 + ld1 {v8.4s}, [x15], x8 + ldr s9, [x19] + cmp x10, #1 + beq ReloadEnd + add x19, x19, x8 + ld1 {v10.4s}, [x15], x8 + ldr s11, [x19] + cmp x10, #2 + beq ReloadEnd + add x19, x19, x8 + ld1 {v12.4s}, [x15], x8 + ldr s13, [x19] + cmp x10, #3 + beq ReloadEnd + add x19, x19, x8 + ld1 {v14.4s}, [x15], x8 + ldr s15, [x19] + cmp x10, #4 + beq ReloadEnd + add x19, x19, x8 + ld1 {v16.4s}, [x15], x8 + ldr s17, [x19] + cmp x10, #5 + beq ReloadEnd + add x19, x19, x8 + ld1 {v18.4s}, [x15], x8 + ldr s19, [x19] + cmp x10, #6 + beq ReloadEnd + add x19, x19, x8 + ld1 {v20.4s}, [x15], x8 + ldr s21, [x19] + cmp x10, #7 + beq ReloadEnd + add x19, x19, x8 + ld1 {v22.4s}, [x15], x8 + ldr s23, [x19] + cmp x10, #8 + beq ReloadEnd + add x19, x19, x8 + ld1 {v24.4s}, [x15], x8 + ldr s25, [x19] + cmp x10, #9 + beq ReloadEnd + add x19, x19, x8 + ld1 {v26.4s}, [x15], x8 + ldr s27, [x19] + cmp x10, #10 + beq ReloadEnd + add x19, x19, x8 + ld1 {v28.4s}, [x15], x8 + ldr s29, [x19] + cmp x10, #11 + beq ReloadEnd + add x19, x19, x8 + ld1 {v30.4s}, [x15], x8 + ldr s31, [x19] + add x15, x15, #20 + b ReloadEnd + Reload6: + add x19, x15, #16 + ld1 {v8.4s}, [x15], x8 + ld1 {v9.2s}, [x19], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + ld1 {v11.2s}, [x19], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + ld1 {v13.2s}, [x19], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + ld1 {v15.2s}, [x19], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + ld1 {v17.2s}, [x19], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + ld1 {v19.2s}, [x19], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + ld1 {v21.2s}, [x19], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + ld1 {v23.2s}, [x19], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + ld1 {v25.2s}, [x19], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, [x15], x8 + ld1 {v27.2s}, [x19], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + ld1 {v29.2s}, [x19], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s}, [x15], x8 + ld1 {v31.2s}, [x19] + add x15, x15, #24 + b ReloadEnd + Reload7: + add x19, x15, #16 + add x16, x15, #24 + ld1 {v8.4s}, [x15], x8 + ld1 {v9.2s}, [x19], x8 + ld1 {v9.s}[2], [x16], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + ld1 {v11.2s}, [x19], x8 + ld1 {v11.s}[2], [x16], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + ld1 {v13.2s}, [x19], x8 + ld1 {v13.s}[2], [x16], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + ld1 {v15.2s}, [x19], x8 + ld1 {v15.s}[2], [x16], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + ld1 {v17.2s}, [x19], x8 + ld1 {v17.s}[2], [x16], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + ld1 {v19.2s}, [x19], x8 + ld1 {v19.s}[2], [x16], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + ld1 {v21.2s}, [x19], x8 + ld1 {v21.s}[2], [x16], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + ld1 {v23.2s}, [x19], x8 + ld1 {v23.s}[2], [x16], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + ld1 {v25.2s}, [x19], x8 + ld1 {v25.s}[2], [x16], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, [x15], x8 + ld1 {v27.2s}, [x19], x8 + ld1 {v27.s}[2], [x16], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + ld1 {v29.2s}, [x19], x8 + ld1 {v29.s}[2], [x16], x8 + cmp x10, #11 + beq ReloadEnd + 
ld1 {v30.4s}, [x15], x8 + ld1 {v31.2s}, [x19] + ld1 {v31.s}[2], [x16] + add x15, x15, #28 + b ReloadEnd + + Reload8: + ld1 {v8.4s, v9.4s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s, v11.4s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s, v13.4s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s, v15.4s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s, v17.4s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s, v19.4s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s, v21.4s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s, v23.4s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s, v25.4s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s, v27.4s}, [x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s, v29.4s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s, v31.4s}, [x15], x8 + add x15, x15, #32 + b ReloadEnd + +ReloadEnd: + ret + +Compute12x8Unit: + subs x21, x21, #2 + ble Compute12x8End + Compute12x8: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x21, x21, #2 + bgt Compute12x8 + + Compute12x8End: + cbnz x21, Compute12x8End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, 
v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + Compute12x8End1: + ld1 {v1.4s, v2.4s}, [x11] + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + ret + +Compute12x4Unit: + subs x21, x21, #2 + ble Compute12x4End + Compute12x4: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + subs x21, x21, #2 + bgt Compute12x4 + + Compute12x4End: + cbnz x21, Compute12x4End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + Compute12x4End1: + ld1 {v1.4s, v2.4s}, [x11] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + ret + +Compute8x8Unit: + subs x21, x21, #2 + ble Compute8x8End + Compute8x8: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm 
pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x21, x21, #2 + bgt Compute8x8 + + Compute8x8End: + cbnz x21, Compute8x8End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + Compute8x8End1: + ld1 {v1.4s}, [x11] + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + ret + +Compute8x4Unit: + subs x21, x21, #2 + ble Compute8x4End + Compute8x4: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + subs x21, x21, #2 + bgt Compute8x4 + + Compute8x4End: + cbnz x21, Compute8x4End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, 
v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + Compute8x4End1: + ld1 {v1.4s}, [x11] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + ret + +Compute4x8Unit: + subs x21, x21, #2 + ble Compute4x8End + Compute4x8: + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + + subs x21, x21, #2 + bgt Compute4x8 + + Compute4x8End: + cbnz x21, Compute4x8End1 + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + Compute4x8End1: + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ret + +Compute4x4Unit: + subs x21, x21, #2 + ble Compute4x4End + Compute4x4: + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + + subs x21, x21, #2 + bgt Compute4x4 + + Compute4x4End: + cbnz x21, Compute4x4End1 + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + Compute4x4End1: + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ret + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Corner.S 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Corner.S new file mode 100644 index 00000000..c03863ef --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Corner.S @@ -0,0 +1,114 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Corner + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such a number of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 // weight + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Horizontal.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Horizontal.S new file mode 100644 index 00000000..b828c30b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Horizontal.S @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Horizontal + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + ld1 {v5.4s}, [x10], x13 + add x15, x11, x4 + ld1 {v2.4s}, [x11], x5 + add x16, x12, x14 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x15], x5 + ld1 {v18.4s}, [x16], x13 + ld1 {v17.4s}, [x15], x5 + ld1 {v19.4s}, [x16], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + add x15, x11, x4 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + add x16, x12, x14 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + fmla v23.4s, v16.4s, v18.4s + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x15], x5 + fmla v23.4s, v17.4s, v19.4s + ld1 {v18.4s}, [x16], x13 + ld1 {v17.4s}, [x15], x5 + ld1 {v19.4s}, [x16], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + fmla v23.4s, v16.4s, v18.4s + fmla v23.4s, v17.4s, v19.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride1.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride1.S new file mode 100644 index 00000000..33958568 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride1.S @@ -0,0 +1,210 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size, +// int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: relu +// w10: relu6 + +asm_function ConvDw3x3Stride1 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr w8, [sp, #128] + ldr w9, [sp, #136] + ldr w10, [sp, #144] + + mov w11, #4 + mul w15, w4, w11 // col_size * 4 + mul w16, w6, w11 // channel * 4 + mul w17, w5, w11 // row_size * 4 + mov w11, #2 + mul w14, w11, w15 // col_size * 2 * 4 + + movi v23.4s, #6 + scvtf v23.4s, v23.4s + dup v24.4s, wzr + + // Load weights + ld1 {v0.4s}, [x2], x16 + ld1 {v1.4s}, [x2], x16 + ld1 {v2.4s}, [x2], x16 + ld1 {v3.4s}, [x2], x16 + ld1 {v4.4s}, [x2], x16 + ld1 {v5.4s}, [x2], x16 + ld1 {v6.4s}, [x2], x16 + ld1 {v7.4s}, [x2], x16 + ld1 {v8.4s}, [x2], x16 + + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + ld1 {v10.4s}, [x11], x15 + ld1 {v11.4s}, [x11], x15 + ld1 {v13.4s}, [x12], x15 + ld1 {v14.4s}, [x12], x15 + ld1 {v15.4s}, [x12], x15 + ld1 {v17.4s}, [x13], x15 + ld1 {v18.4s}, [x13], x15 + ld1 {v19.4s}, [x13], x15 + + ld1 {v21.4s}, [x3] + ld1 {v22.4s}, [x3] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +WIDTH2_LOOP: + fmla v21.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11] + ld1 {v16.4s}, [x12] + fmla v22.4s, v0.4s, v10.4s + ld1 {v20.4s}, [x13] + add x1, x1, x14 + fmla v21.4s, v1.4s, v10.4s + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + fmla v22.4s, v1.4s, v11.4s + ld1 {v10.4s}, [x11], x15 + fmla v21.4s, v2.4s, v11.4s + fmla v22.4s, v2.4s, v12.4s + fmla v21.4s, v3.4s, v13.4s + ld1 {v11.4s}, [x11], x15 + fmla v22.4s, v3.4s, v14.4s + fmla v21.4s, v4.4s, v14.4s + ld1 {v13.4s}, [x12], x15 + fmla v22.4s, v4.4s, v15.4s + fmla v21.4s, v5.4s, v15.4s + ld1 {v14.4s}, [x12], x15 + fmla v22.4s, v5.4s, v16.4s + fmla v21.4s, v6.4s, v17.4s + ld1 {v15.4s}, [x12], x15 + fmla v22.4s, v6.4s, v18.4s + fmla v21.4s, v7.4s, v18.4s + ld1 {v17.4s}, [x13], x15 + fmla v22.4s, v7.4s, v19.4s + fmla v21.4s, v8.4s, v19.4s + ld1 {v18.4s}, [x13], x15 + fmla v22.4s, v8.4s, v20.4s + ld1 {v19.4s}, [x13], x15 + + cbnz x10, WIDTH2_RELU6 + cbnz x9, WIDTH2_RELU + b WIDTH2_WRITE + WIDTH2_RELU6: + fmin v21.4s, v21.4s, v23.4s + fmin v22.4s, v22.4s, v23.4s + WIDTH2_RELU: + fmax v21.4s, v21.4s, v24.4s + fmax v22.4s, v22.4s, v24.4s + WIDTH2_WRITE: + st1 {v21.4s}, [x0], x16 + ld1 {v21.4s}, [x3] + st1 {v22.4s}, [x0], x16 + ld1 {v22.4s}, [x3] + + sub w8, w8, #2 + cmp w8, #2 + bgt WIDTH2_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + fmla v21.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11] + fmla v22.4s, v0.4s, v10.4s + 
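// width-2 tail: v21/v22 accumulate the last two output pixels; v22 reads the input window shifted one column right (v10-v12, v14-v16, v18-v20) +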
fmla v21.4s, v1.4s, v10.4s + ld1 {v16.4s}, [x12] + fmla v22.4s, v1.4s, v11.4s + fmla v21.4s, v2.4s, v11.4s + ld1 {v20.4s}, [x13] + fmla v22.4s, v2.4s, v12.4s + fmla v21.4s, v3.4s, v13.4s + fmla v22.4s, v3.4s, v14.4s + fmla v21.4s, v4.4s, v14.4s + fmla v22.4s, v4.4s, v15.4s + fmla v21.4s, v5.4s, v15.4s + fmla v22.4s, v5.4s, v16.4s + fmla v21.4s, v6.4s, v17.4s + fmla v22.4s, v6.4s, v18.4s + fmla v21.4s, v7.4s, v18.4s + fmla v22.4s, v7.4s, v19.4s + fmla v21.4s, v8.4s, v19.4s + fmla v22.4s, v8.4s, v20.4s + + cbnz x10, WIDTH2_LEFT_RELU6 + cbnz x9, WIDTH2_LEFT_RELU + b WIDTH2_LEFT_WRITE + WIDTH2_LEFT_RELU6: + fmin v21.4s, v21.4s, v23.4s + fmin v22.4s, v22.4s, v23.4s + WIDTH2_LEFT_RELU: + fmax v21.4s, v21.4s, v24.4s + fmax v22.4s, v22.4s, v24.4s + WIDTH2_LEFT_WRITE: + st1 {v21.4s}, [x0], x16 + st1 {v22.4s}, [x0], x16 + b End + +WIDTH1_LEFT: + fmla v21.4s, v0.4s, v9.4s + fmla v21.4s, v1.4s, v10.4s + fmla v21.4s, v2.4s, v11.4s + fmla v21.4s, v3.4s, v13.4s + fmla v21.4s, v4.4s, v14.4s + fmla v21.4s, v5.4s, v15.4s + fmla v21.4s, v6.4s, v17.4s + fmla v21.4s, v7.4s, v18.4s + fmla v21.4s, v8.4s, v19.4s + + cbnz x10, WIDTH1_RELU6 + cbnz x9, WIDTH1_RELU + b WIDTH1_WRITE + WIDTH1_RELU6: + fmin v21.4s, v21.4s, v23.4s + WIDTH1_RELU: + fmax v21.4s, v21.4s, v24.4s + WIDTH1_WRITE: + st1 {v21.4s}, [x0] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride2.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride2.S new file mode 100644 index 00000000..b3d90e05 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Stride2.S @@ -0,0 +1,212 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size, +// int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: relu +// w10: relu6 + +asm_function ConvDw3x3Stride2 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr w8, [sp, #128] + ldr w9, [sp, #136] + ldr w10, [sp, #144] + + mov w11, #4 + mul w15, w4, w11 // col_size * 4 + mul w16, w6, w11 // channel * 4 + mul w17, w5, w11 // row_size * 4 + mov w11, #2 + mul w14, w11, w15 // col_size * 2 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + // Load weights + ld1 {v0.4s}, [x2], x16 + ld1 {v1.4s}, [x2], x16 + ld1 {v2.4s}, [x2], x16 + ld1 {v3.4s}, [x2], x16 + ld1 {v4.4s}, [x2], x16 + ld1 {v5.4s}, [x2], x16 + ld1 {v6.4s}, [x2], x16 + ld1 {v7.4s}, [x2], x16 + ld1 {v8.4s}, [x2], x16 + + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + ld1 {v10.4s}, [x11], x15 + ld1 {v11.4s}, [x11], x15 + ld1 {v14.4s}, [x12], x15 + ld1 {v15.4s}, [x12], x15 + ld1 {v16.4s}, [x12], x15 + ld1 {v19.4s}, [x13], x15 + ld1 {v20.4s}, [x13], x15 + ld1 {v21.4s}, [x13], x15 + + ld1 {v24.4s}, [x3] + ld1 {v25.4s}, [x3] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +WIDTH2_LOOP: + fmla v24.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11], x15 + fmla v25.4s, v0.4s, v11.4s + ld1 {v17.4s}, [x12], x15 + fmla v24.4s, v1.4s, v10.4s + ld1 {v22.4s}, [x13], x15 + fmla v25.4s, v1.4s, v12.4s + ld1 {v13.4s}, [x11], x15 + fmla v24.4s, v2.4s, v11.4s + ld1 {v18.4s}, [x12], x15 + fmla v25.4s, v2.4s, v13.4s + ld1 {v23.4s}, [x13], x15 + fmla v24.4s, v3.4s, v14.4s + mov v9.16b, v13.16b + fmla v25.4s, v3.4s, v16.4s + mov v14.16b, v18.16b + fmla v24.4s, v4.4s, v15.4s + fmla v25.4s, v4.4s, v17.4s + ld1 {v10.4s}, [x11], x15 + fmla v24.4s, v5.4s, v16.4s + ld1 {v11.4s}, [x11], x15 + fmla v25.4s, v5.4s, v18.4s + ld1 {v15.4s}, [x12], x15 + fmla v24.4s, v6.4s, v19.4s + ld1 {v16.4s}, [x12], x15 + fmla v25.4s, v6.4s, v21.4s + mov v19.16b, v23.16b + fmla v24.4s, v7.4s, v20.4s + fmla v25.4s, v7.4s, v22.4s + ld1 {v20.4s}, [x13], x15 + fmla v24.4s, v8.4s, v21.4s + fmla v25.4s, v8.4s, v23.4s + ld1 {v21.4s}, [x13], x15 + + cbnz x10, WIDTH2_RELU6 + cbnz x9, WIDTH2_RELU + b WIDTH2_WRITE + WIDTH2_RELU6: + fmin v24.4s, v24.4s, v26.4s + fmin v25.4s, v25.4s, v26.4s + WIDTH2_RELU: + fmax v24.4s, v24.4s, v27.4s + fmax v25.4s, v25.4s, v27.4s + WIDTH2_WRITE: + st1 {v24.4s}, [x0], x16 + ld1 {v24.4s}, [x3] + st1 {v25.4s}, [x0], x16 + ld1 {v25.4s}, [x3] + + sub w8, w8, #2 + cmp w8, #2 + bgt WIDTH2_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + fmla v24.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11], x15 + fmla v25.4s, v0.4s, v11.4s + ld1 {v17.4s}, [x12], x15 + fmla v24.4s, v1.4s, v10.4s + ld1 {v22.4s}, [x13], x15 + fmla v25.4s, v1.4s, v12.4s + ld1 {v13.4s}, [x11], x15 + fmla v24.4s, v2.4s, v11.4s + ld1 {v18.4s}, [x12], x15 + fmla v25.4s, v2.4s, v13.4s + ld1 {v23.4s}, [x13], x15 + fmla v24.4s, v3.4s, v14.4s + fmla v25.4s, v3.4s, v16.4s + fmla v24.4s, v4.4s, v15.4s + fmla v25.4s, v4.4s, v17.4s + fmla v24.4s, v5.4s, v16.4s + fmla v25.4s, v5.4s, v18.4s + fmla v24.4s, v6.4s, v19.4s + fmla v25.4s, v6.4s, v21.4s + fmla v24.4s, v7.4s, v20.4s + fmla v25.4s, v7.4s, v22.4s + 
fmla v24.4s, v8.4s, v21.4s + fmla v25.4s, v8.4s, v23.4s + + cbnz x10, WIDTH2_LEFT_RELU6 + cbnz x9, WIDTH2_LEFT_RELU + b WIDTH2_LEFT_WRITE + WIDTH2_LEFT_RELU6: + fmin v24.4s, v24.4s, v26.4s + fmin v25.4s, v25.4s, v26.4s + WIDTH2_LEFT_RELU: + fmax v24.4s, v24.4s, v27.4s + fmax v25.4s, v25.4s, v27.4s + WIDTH2_LEFT_WRITE: + st1 {v24.4s}, [x0], x16 + st1 {v25.4s}, [x0], x16 + b End + +WIDTH1_LEFT: + fmla v24.4s, v0.4s, v9.4s + fmla v24.4s, v1.4s, v10.4s + fmla v24.4s, v2.4s, v11.4s + fmla v24.4s, v3.4s, v14.4s + fmla v24.4s, v4.4s, v15.4s + fmla v24.4s, v5.4s, v16.4s + fmla v24.4s, v6.4s, v19.4s + fmla v24.4s, v7.4s, v20.4s + fmla v24.4s, v8.4s, v21.4s + + cbnz x10, WIDTH1_RELU6 + cbnz x9, WIDTH1_RELU + b WIDTH1_WRITE + WIDTH1_RELU6: + fmin v24.4s, v24.4s, v26.4s + WIDTH1_RELU: + fmax v24.4s, v24.4s, v27.4s + WIDTH1_WRITE: + st1 {v24.4s}, [x0] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Vertical.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Vertical.S new file mode 100644 index 00000000..dbacdb46 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Fp32Vertical.S @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Vertical + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such a number of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 + ld1 {v1.4s}, [x9], x5 + add x12, x2, x14 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x9], x5 + ld1 {v18.4s}, [x10], x13 + ld1 {v17.4s}, [x11], x5 + ld1 {v19.4s}, [x12], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + fmla v23.4s, v16.4s, v18.4s + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x9], x5 + fmla v23.4s, v17.4s, v19.4s + ld1 {v18.4s}, [x10], x13 + ld1 {v17.4s}, [x11], x5 + ld1 {v19.4s}, [x12], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + fmla v23.4s, v16.4s, v18.4s + fmla v23.4s, v17.4s, v19.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8.S new file mode 100644 index 00000000..9c5237e6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8.S @@ -0,0 +1,500 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, +// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, +// int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max, +// size_t per_channel) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: in_zp +// w10: out_zp +// w11: out_multiplier +// w12: left_shift +// w13: right_shift +// w14: acc_min +// w15: acc_max +// w16: per_channel + +asm_function ConvDw3x3Int8Neon64 + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + ldr x23, [sp, #256] // per_channel + + add x19, x3, #16 + add w20, w6, w6 // channel * 2 + add w21, w4, w4 // col_size * 2 + dup v25.8b, w9 + + cbnz w23, PER_CHANNEL_DUMP + PER_LAYER_DUMP: + ld1r {v27.4s}, [x11] // out_multiplier + ld1r {v26.4s}, [x12] // left_shift + ld1r {v28.4s}, [x13] // right_shift + b MAIN_FUC + PER_CHANNEL_DUMP: + ld1 {v27.4s}, [x11] + ld1 {v26.4s}, [x12] + ld1 {v28.4s}, [x13] + MAIN_FUC: + dup v29.4s, w10 + dup v30.4s, w14 + dup v31.4s, w15 + ldr w24, [x12] + + // Load weights + ld1 {v0.8h}, [x2], x20 + ld1 {v1.8h}, [x2], x20 + ld1 {v2.8h}, [x2], x20 + ld1 {v3.8h}, [x2], x20 + ld1 {v4.8h}, [x2], x20 + ld1 {v5.8h}, [x2], x20 + ld1 {v6.8h}, [x2], x20 + ld1 {v7.8h}, [x2], x20 + ld1 {v8.8h}, [x2], x20 + + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + ld1 {v9.8b}, [x16], x4 + ld1 {v10.8b}, [x16], x4 + ld1 {v11.8b}, [x16], x4 + ld1 {v13.8b}, [x17], x4 + ld1 {v14.8b}, [x17], x4 + ld1 {v15.8b}, [x17], x4 + ld1 {v17.8b}, [x25], x4 + ld1 {v18.8b}, [x25], x4 + ld1 {v19.8b}, [x25], x4 + + ld1 {v21.4s}, [x3] + ld1 {v22.4s}, [x19] + ld1 {v23.4s}, [x3] + ld1 {v24.4s}, [x19] + + // subtract input zp + ssubl v9.8h, v9.8b, v25.8b + ssubl v10.8h, v10.8b, v25.8b + ssubl v11.8h, v11.8b, v25.8b + ssubl v13.8h, v13.8b, v25.8b + ssubl v14.8h, v14.8b, v25.8b + ssubl v15.8h, v15.8b, v25.8b + ssubl v17.8h, v17.8b, v25.8b + ssubl v18.8h, v18.8b, v25.8b + ssubl v19.8h, v19.8b, v25.8b + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +HEIGHT1_LOOP: + smlal v21.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16] + smlal2 v22.4s, v0.8h, v9.8h + ld1 {v16.8b}, [x17] + smlal v23.4s, v0.4h, v10.4h + smlal2 v24.4s, v0.8h, v10.8h + ld1 {v20.8b}, [x25] + add x1, x1, x21 + ssubl v12.8h, v12.8b, v25.8b + smlal v21.4s, v1.4h, v10.4h + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + smlal2 v22.4s, v1.8h, v10.8h + ld1 {v9.8b}, [x16], x4 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v1.4h, v11.4h + ld1 {v10.8b}, [x16], x4 + ssubl v20.8h, v20.8b, v25.8b + smlal2 v24.4s, v1.8h, v11.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + ld1 {v11.8b}, [x16], x4 + smlal v23.4s, v2.4h, v12.4h + smlal2 v24.4s, v2.8h, v12.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + ld1 {v13.8b}, [x17], x4 + smlal v23.4s, v3.4h, v14.4h + smlal2 v24.4s, v3.8h, v14.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + ld1 {v14.8b}, 
[x17], x4 + smlal v23.4s, v4.4h, v15.4h + smlal2 v24.4s, v4.8h, v15.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h + ld1 {v15.8b}, [x17], x4 + smlal v23.4s, v5.4h, v16.4h + smlal2 v24.4s, v5.8h, v16.8h + smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + ld1 {v17.8b}, [x25], x4 + smlal v23.4s, v6.4h, v18.4h + smlal2 v24.4s, v6.8h, v18.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + ld1 {v18.8b}, [x25], x4 + smlal v23.4s, v7.4h, v19.4h + smlal2 v24.4s, v7.8h, v19.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + ld1 {v19.8b}, [x25], x4 + smlal v23.4s, v8.4h, v20.4h + smlal2 v24.4s, v8.8h, v20.8h + + cbnz w23, PER_CHANNEL_POST1 + cbz w24, SKIP_LEFTSHIFT1 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b OUTZP1 + +SKIP_LEFTSHIFT1: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + + and v12.16b, v21.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v21.4s, v21.4s, v12.4s + sqrshl v21.4s, v21.4s, v28.4s + + and v12.16b, v22.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v22.4s, v22.4s, v12.4s + sqrshl v22.4s, v22.4s, v28.4s + + and v12.16b, v23.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v23.4s, v23.4s, v12.4s + sqrshl v23.4s, v23.4s, v28.4s + + and v12.16b, v24.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v24.4s, v24.4s, v12.4s + sqrshl v24.4s, v24.4s, v28.4s + b OUTZP1 + +PER_CHANNEL_POST1: + sqshl v21.4s, v21.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + ldr q26, [x12, #16] + + and v12.16b, v21.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v21.4s, v21.4s, v12.4s + sqrshl v21.4s, v21.4s, v28.4s + + and v12.16b, v23.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v23.4s, v23.4s, v12.4s + sqrshl v23.4s, v23.4s, v28.4s + + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v26.4s}, [x12] + + and v12.16b, v22.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v22.4s, v22.4s, v12.4s + sqrshl v22.4s, v22.4s, v28.4s + + and v12.16b, v24.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v24.4s, v24.4s, v12.4s + sqrshl v24.4s, v24.4s, v28.4s + + ld1 {v27.4s}, [x11] + ld1 {v28.4s}, [x13] + +OUTZP1: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + sqadd v23.4s, v23.4s, v29.4s + sqadd v24.4s, v24.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + ld1 {v22.4s}, [x19] + ssubl v9.8h, v9.8b, v25.8b + ssubl v10.8h, v10.8b, v25.8b + sqxtn v23.4h, v23.4s + sqxtn2 v23.8h, v24.4s + ld1 {v24.4s}, [x19] + sqxtn v21.8b, v21.8h + sqxtn2 v21.16b, v23.8h + st1 {v21.8b}, [x0], x6 + mov v23.d[0], v21.d[1] + ld1 {v21.4s}, [x3] + st1 {v23.8b}, [x0], x6 + ssubl v11.8h, v11.8b, v25.8b + ssubl v13.8h, v13.8b, v25.8b + ld1 {v23.4s}, [x3] + ssubl v14.8h, v14.8b, v25.8b + ssubl v15.8h, v15.8b, v25.8b + ssubl v17.8h, v17.8b, v25.8b + ssubl v18.8h, 
v18.8b, v25.8b + ssubl v19.8h, v19.8b, v25.8b + sub w8, w8, #2 + cmp w8, #2 + bgt HEIGHT1_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + smlal v21.4s, v0.4h, v9.4h + smlal2 v22.4s, v0.8h, v9.8h + ld1 {v12.8b}, [x16] + ssubl v12.8h, v12.8b, v25.8b + smlal v23.4s, v0.4h, v10.4h + smlal2 v24.4s, v0.8h, v10.8h + smlal v21.4s, v1.4h, v10.4h + smlal2 v22.4s, v1.8h, v10.8h + ld1 {v16.8b}, [x17] + smlal v23.4s, v1.4h, v11.4h + smlal2 v24.4s, v1.8h, v11.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + ld1 {v20.8b}, [x25] + smlal v23.4s, v2.4h, v12.4h + smlal2 v24.4s, v2.8h, v12.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + smlal v23.4s, v3.4h, v14.4h + smlal2 v24.4s, v3.8h, v14.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v4.4h, v15.4h + smlal2 v24.4s, v4.8h, v15.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h + ssubl v20.8h, v20.8b, v25.8b + smlal v23.4s, v5.4h, v16.4h + smlal2 v24.4s, v5.8h, v16.8h + smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + smlal v23.4s, v6.4h, v18.4h + smlal2 v24.4s, v6.8h, v18.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + smlal v23.4s, v7.4h, v19.4h + smlal2 v24.4s, v7.8h, v19.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + smlal v23.4s, v8.4h, v20.4h + smlal2 v24.4s, v8.8h, v20.8h + + cbnz w23, PER_CHANNEL_POST2 + cbz w24, SKIP_LEFTSHIFT2 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b OUTZP2 + +SKIP_LEFTSHIFT2: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v22.4s, v22.4s, v28.4s + sqrshl v23.4s, v23.4s, v28.4s + sqrshl v24.4s, v24.4s, v28.4s + b OUTZP2 + +PER_CHANNEL_POST2: + sqshl v21.4s, v21.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + ldr q26, [x12, #16] + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v23.4s, v23.4s, v28.4s + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v22.4s, v22.4s, v28.4s + sqrshl v24.4s, v24.4s, v28.4s + +OUTZP2: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + sqadd v23.4s, v23.4s, v29.4s + sqadd v24.4s, v24.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + sqxtn v23.4h, v23.4s + sqxtn2 v23.8h, v24.4s + sqxtn v21.8b, v21.8h + sqxtn2 v21.16b, v23.8h + st1 {v21.8b}, [x0], x6 + mov v23.d[0], v21.d[1] + st1 {v23.8b}, [x0], x6 + b End + +WIDTH1_LEFT: + smlal v21.4s, v0.4h, v9.4h + smlal2 v22.4s, v0.8h, v9.8h + smlal v21.4s, v1.4h, v10.4h + smlal2 v22.4s, v1.8h, v10.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h 
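+ // last output pixel: taps v6-v8 below add the third input row (v17-v19); the smlal/smlal2 pairs keep low/high 4-channel int32 accumulators in v21/v22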
+ smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + + cbnz w23, PER_CHANNEL_POST3 + cbz w24, SKIP_LEFTSHIFT3 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + b OUTZP3 + +SKIP_LEFTSHIFT3: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v22.4s, v22.4s, v28.4s + b OUTZP3 + +PER_CHANNEL_POST3: + sqshl v21.4s, v21.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + ldr q26, [x12, #16] + sqrshl v21.4s, v21.4s, v28.4s + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrshl v22.4s, v22.4s, v28.4s + +OUTZP3: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + sqxtn v21.8b, v21.8h + st1 {v21.8b}, [x0], x6 + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Corner.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Corner.S new file mode 100644 index 00000000..3af1f985 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Corner.S @@ -0,0 +1,222 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Corner + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such a number of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #32] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #40] // out_multiplier + ldr x10, [sp, #48] // left_shift + ldr x11, [sp, #56] // right_shift + ldr x12, [sp, #64] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #72] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #80] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + + mov x12, #2 + mul x21, x6, x12 // x6 * 2 + mov x12, #3 + mul x22, x21, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x21 // weight + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x21 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + ld1 {v6.8h}, [x20], x21 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x21 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x21 + smlal2 v24.4s, v2.8h, v6.8h + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x21 + smlal2 v24.4s, v3.8h, v7.8h + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + + cbnz x14, PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 +
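// per-channel requantization of the high 4 channels: rounding right shift by v29 (right_shift), reloading the next shift/multiplier vectors as they are consumed +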
sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Horizontal.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Horizontal.S new file mode 100644 index 00000000..88b2e2be --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Horizontal.S @@ -0,0 +1,255 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Horizontal + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such a number of parameters + sub sp, sp, #48 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #48] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #56] // out_multiplier + ldr x10, [sp, #64] // left_shift + ldr x11, [sp, #72] // right_shift + ldr x12, [sp, #80] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #88] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #96] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + ldr x12, [sp, #80] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #88] + dup v31.4s, w13 // acc_max + + mov x12, #2 + mul x23, x6, x12 // x6 * 2 + mov x12, #3 + mul x24, x23, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x23 // weight + add x20, x2, x24 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x23 + add x21, x19, x4 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + add x22, x20, x24 + ld1 {v6.8h}, [x20], x23 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x23 + ld1 {v16.8b}, [x21], x5 + ssubl v16.8h, v16.8b, v25.8b + ld1 {v18.8h}, [x22], x23 + ld1 {v17.8b}, [x21], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x22], x23 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x23 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + + add x20, x2, x24 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x23 + smlal2 v24.4s, v2.8h, v6.8h + + add x21, x19, x4 + add x22, x20, x24 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x23 + smlal2 v24.4s, v3.8h, v7.8h + + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + smlal v23.4s, v16.4h, v18.4h + ld1 {v7.8h}, [x20], x23 + smlal2 v24.4s, v16.8h, v18.8h + + ld1 {v16.8b}, [x21], x5 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v17.4h, v19.4h + ld1 {v18.8h}, [x22], x23 + smlal2 v24.4s, v17.8h, v19.8h + ld1 {v17.8b}, [x21], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x22], x23 + + cbnz x14,
PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + smlal v23.4s, v16.4h, v18.4h + smlal2 v24.4s, v16.8h, v18.8h + smlal v23.4s, v17.4h, v19.4h + smlal2 v24.4s, v17.8h, v19.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Stride2.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Stride2.S new file mode 100644 index 00000000..0209dfe7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Stride2.S @@ -0,0 +1,474 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, +// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, +// int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max, +// size_t per_channel) + +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: in_zp +// w10: out_zp +// w11: out_multiplier +// w12: left_shift +// w13: right_shift +// w14: acc_min +// w15: acc_max +// w16: per_channel + +asm_function ConvDw3x3Int8Stride2 + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + ldr x23, [sp, #256] // per_channel + + add x19, x3, #16 + add w20, w6, w6 // channel * 2 + dup v28.8b, w9 // in_zp + ldr w24, [x12] + + dup v29.4s, w10 + dup v30.4s, w14 + dup v31.4s, w15 + // Load weights + ld1 {v0.8h}, [x2], x20 + ld1 {v1.8h}, [x2], x20 + ld1 {v2.8h}, [x2], x20 + ld1 {v3.8h}, [x2], x20 + ld1 {v4.8h}, [x2], x20 + ld1 {v5.8h}, [x2], x20 + ld1 {v6.8h}, [x2], x20 + ld1 {v7.8h}, [x2], x20 + ld1 {v8.8h}, [x2], x20 + + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + ld1 {v9.8b}, [x16], x4 + ld1 {v10.8b}, [x16], x4 + ssubl v9.8h, v9.8b, v28.8b + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + ld1 {v14.8b}, [x17], x4 + ssubl v11.8h, v11.8b, v28.8b + ld1 {v15.8b}, [x17], x4 + ssubl v14.8h, v14.8b, v28.8b + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + ld1 {v19.8b}, [x25], x4 + ssubl v16.8h, v16.8b, v28.8b + ld1 {v20.8b}, [x25], x4 + ssubl v19.8h, v19.8b, v28.8b + ld1 {v21.8b}, [x25], x4 + ssubl v20.8h, v20.8b, v28.8b + ssubl v21.8h, v21.8b, v28.8b + + ld1 {v24.4s}, [x3] + ld1 {v25.4s}, [x19] + ld1 {v26.4s}, [x3] + ld1 {v27.4s}, [x19] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +HEIGHT1_LOOP: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x25], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b + smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x25], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + mov v9.16b, v13.16b + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + ld1 {v10.8b}, [x16], x4 + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + smlal v26.4s, v2.4h, v13.4h + ssubl v11.8h, v11.8b, v28.8b + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + mov v14.16b, v18.16b + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + ld1 {v15.8b}, [x17], x4 + smlal v26.4s, v4.4h, v17.4h +
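// stride-2 main loop: output pixel 0 accumulates in v24/v25 (low/high channels), pixel 1 in v26/v27, its window starting two input columns later +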
smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + smlal v26.4s, v5.4h, v18.4h + ssubl v16.8h, v16.8b, v28.8b + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + mov v19.16b, v23.16b + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + ld1 {v20.8b}, [x25], x4 + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + ld1 {v21.8b}, [x25], x4 + ssubl v20.8h, v20.8b, v28.8b + smlal v26.4s, v8.4h, v23.4h + ssubl v21.8h, v21.8b, v28.8b + smlal2 v27.4s, v8.8h, v23.8h + + cbnz w23, PER_CHANNEL_POST1 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT1 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + b OUTZP1 + +SKIP_LEFTSHIFT1: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v22.4s + b OUTZP1 + +PER_CHANNEL_POST1: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v13.4s + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v18.4s + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v23.4s + +OUTZP1: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + sqadd v26.4s, v26.4s, v29.4s + sqadd v27.4s, v27.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + smax v26.4s, v26.4s, v30.4s + smax v27.4s, v27.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + smin v26.4s, v26.4s, v31.4s + smin v27.4s, v27.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + ld1 {v25.4s}, [x19] + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + ld1 {v27.4s}, [x19] + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + ld1 {v24.4s}, [x3] + st1 {v26.8b}, [x0], x6 + ld1 {v26.4s}, [x3] + sub w8, w8, #2 + cmp w8, #2 + bgt HEIGHT1_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x25], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b + smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x25], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v26.4s, v2.4h, v13.4h + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, 
v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v26.4s, v4.4h, v17.4h + smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v26.4s, v5.4h, v18.4h + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + smlal v26.4s, v8.4h, v23.4h + smlal2 v27.4s, v8.8h, v23.8h + + cbnz w23, PER_CHANNEL_POST2 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT2 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + b OUTZP2 + +SKIP_LEFTSHIFT2: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v22.4s + b OUTZP2 + +PER_CHANNEL_POST2: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v13.4s + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v18.4s + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v23.4s + +OUTZP2: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + sqadd v26.4s, v26.4s, v29.4s + sqadd v27.4s, v27.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + smax v26.4s, v26.4s, v30.4s + smax v27.4s, v27.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + smin v26.4s, v26.4s, v31.4s + smin v27.4s, v27.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + st1 {v26.8b}, [x0], x6 + b End + +WIDTH1_LEFT: + smlal v24.4s, v0.4h, v9.4h + smlal2 v25.4s, v0.8h, v9.8h + smlal v24.4s, v1.4h, v10.4h + smlal2 v25.4s, v1.8h, v10.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + + cbnz w23, PER_CHANNEL_POST3 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT3 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + b OUTZP3 + +SKIP_LEFTSHIFT3: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + 
sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + b OUTZP3 + +PER_CHANNEL_POST3: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + +OUTZP3: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v24.8b, v24.8h + st1 {v24.8b}, [x0], x6 + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Vertical.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Vertical.S new file mode 100644 index 00000000..a0c2ca54 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Int8Vertical.S @@ -0,0 +1,245 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
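For reference, the per-layer paths above (SKIP_LEFTSHIFT*/OUTZP*) implement the usual fixed-point requantization: an optional saturating left shift, a doubling high-half multiply (sqrdmulh), a rounding right shift (sqrshl with a negative count), the output zero point, clamping, and narrowing. The PER_CHANNEL_POST* paths do the same but fetch the multiplier and both shifts per 4-channel group from arrays. A scalar C sketch under those assumptions; requant_one and its helpers are illustrative names, not nnacl APIs, and saturation is only approximated:

#include <stdint.h>

/* Doubling, rounding high-half multiply: behaves like sqrdmulh except
 * for the single saturating case a == b == INT32_MIN. */
static int32_t sqrdmulh32(int32_t a, int32_t b) {
  return (int32_t)(((int64_t)a * b + (1LL << 30)) >> 31);
}

/* sqrshl with a negative count: arithmetic right shift by -right_shift
 * with round-to-nearest (half added before shifting). */
static int32_t rounding_rshift(int32_t v, int32_t right_shift) {
  int n = -right_shift;
  if (n <= 0) return v;
  return (int32_t)(((int64_t)v + (1LL << (n - 1))) >> n);
}

static int8_t requant_one(int32_t acc, int32_t multiplier, int32_t left_shift,
                          int32_t right_shift, int32_t out_zp,
                          int32_t acc_min, int32_t acc_max) {
  if (left_shift != 0) {
    acc = sqrdmulh32((int32_t)((int64_t)acc << left_shift), multiplier);
  } else {
    acc = rounding_rshift(sqrdmulh32(acc, multiplier), right_shift);
  }
  acc += out_zp;                        /* OUTZP*: add output zero point */
  acc = acc < acc_min ? acc_min : acc;  /* smax: lower clamp */
  acc = acc > acc_max ? acc_max : acc;  /* smin: upper clamp */
  return (int8_t)acc;                   /* sqxtn twice: int32 -> int8 */
}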
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Vertical + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #32] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #40] // out_multiplier + ldr x10, [sp, #48] // left_shift + ldr x11, [sp, #56] // right_shift + ldr x12, [sp, #64] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #72] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #80] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + + mov x12, #2 + mul x21, x6, x12 // x6 * 2 + mov x12, #3 + mul x22, x21, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x21 // weight + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x21 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + ld1 {v6.8h}, [x20], x21 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + ld1 {v16.8b}, [x12], x5 + ssubl v16.8h, v16.8b, v25.8b + ld1 {v18.8h}, [x13], x21 + ld1 {v17.8b}, [x19], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x20], x21 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x21 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x21 + smlal2 v24.4s, v2.8h, v6.8h + + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x21 + smlal2 v24.4s, v3.8h, v7.8h + + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + smlal v23.4s, v16.4h, v18.4h + ld1 {v7.8h}, [x20], x21 + smlal2 v24.4s, v16.8h, v18.8h + + ld1 {v16.8b}, [x12], x5 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v17.4h, v19.4h + ld1 {v18.8h}, [x13], x21 + smlal2 v24.4s, v17.8h, v19.8h + ld1 {v17.8b}, [x19], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x20], x21 + + cbnz x14, PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + 
RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + smlal v23.4s, v16.4h, v18.4h + smlal2 v24.4s, v16.8h, v18.8h + smlal v23.4s, v17.4h, v19.4h + smlal2 v24.4s, v17.8h, v19.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Line.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Line.S new file mode 100644 index 00000000..3a0f8af6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDw3x3Line.S @@ -0,0 +1,203 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
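The ssubl/smlal(2) pairs that dominate ConvDw3x3Int8Vertical above are a widening multiply-accumulate: the input zero point is subtracted in 16 bits and each 8-lane product is accumulated into two int32 quadwords. Per channel that reduces to the following sketch; the tap count and pointer walk are schematic, not the kernel's actual layout:

#include <stdint.h>

/* One channel of an int8 depthwise MAC: zero-point subtraction in int16
 * (ssubl), then 16x16 -> 32-bit accumulation (smlal/smlal2). */
static int32_t dw_int8_mac(int32_t acc, const int8_t *const *src_taps,
                           const int16_t *const *wgt_taps, int taps,
                           int c, int8_t in_zp) {
  for (int k = 0; k < taps; ++k) {
    int16_t x = (int16_t)(src_taps[k][c] - in_zp);
    acc += (int32_t)x * wgt_taps[k][c];
  }
  return acc;
}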
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, +// bool relu, bool relu6) + +// x0: dst, x1: lines, x2: weight, x3: bias, x4: width, x5: ori_channel, x6: relu, x7: relu6 +asm_function ConvDw3x3Line + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr x8, [x1] + ldr x9, [x1, #8] + ldr x10, [x1, #16] + mov x11, x5 + mov x16, #4 + mul x16, x5, x16 + + mov w14, #6 + dup v30.4s, w14 + scvtf v30.4s, v30.4s + + LoopC4: + cbz x3, NoBias + ld1 {v31.4s}, [x3], #16 + NoBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + mov x12, x0 + mov x13, x4 + + cmp x13, #2 + blt LoopOwRemain + LoopOw2: + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x8], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x9], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64 + fmul v24.4s, v12.4s, v0.4s + fmul v25.4s, v13.4s, v1.4s + fmul v26.4s, v14.4s, v2.4s + fmul v27.4s, v15.4s, v3.4s + fmla v24.4s, v16.4s, v4.4s + fmla v25.4s, v17.4s, v5.4s + fmla v26.4s, v18.4s, v6.4s + fmla v27.4s, v19.4s, v7.4s + fmla v24.4s, v20.4s, v8.4s + fmla v25.4s, v21.4s, v9.4s + fmla v26.4s, v22.4s, v10.4s + fmla v27.4s, v23.4s, v11.4s + + fadd v28.4s, v25.4s, v26.4s + fadd v28.4s, v28.4s, v24.4s + fsub v29.4s, v27.4s, v26.4s + fadd v29.4s, v29.4s, v25.4s + + cbz x3, Activation + Bias: + fadd v28.4s, v28.4s, v31.4s + fadd v29.4s, v29.4s, v31.4s + + Activation: + cbnz x7, Relu6 + cbnz x6, Relu + b Write + Relu6: + fmin v28.4s, v28.4s, v30.4s + fmin v29.4s, v29.4s, v30.4s + Relu: + movi v27.16b, #0 + fmax v28.4s, v28.4s, v27.4s + fmax v29.4s, v29.4s, v27.4s + Write: + add x15, x12, x16 + cmp x11, #4 + bge Write4 + cmp x11, #3 + beq Write3 + cmp x11, #2 + beq Write2 + cmp x11, #1 + beq Write1 + + Write1: + str s28, [x12] + str s29, [x15] + b WriteEnd + Write2: + st1 {v28.2s}, [x12] + st1 {v29.2s}, [x15] + b WriteEnd + Write3: + st1 {v28.2s}, [x12] + add x17, x12, #8 + st1 {v28.s}[2], [x17] + st1 {v29.2s}, [x15] + add x18, x15, #8 + st1 {v29.s}[2], [x18] + b WriteEnd + Write4: + st1 {v28.4s}, [x12] + st1 {v29.4s}, [x15] + + WriteEnd: + add x12, x15, x16 + sub x13, x13, #2 + cmp x13, #2 + bge LoopOw2 + cmp x13, #0 + beq LoopOwEnd + + LoopOwRemain: + ld1 {v12.4s, v13.4s, v14.4s}, [x8] + add x8, x8, #64 + ld1 {v16.4s, v17.4s, v18.4s}, [x9] + add x9, x9, #64 + ld1 {v20.4s, v21.4s, v22.4s}, [x10] + add x10, x10, #64 + fmul v24.4s, v12.4s, v0.4s + fmul v25.4s, v13.4s, v1.4s + fmul v26.4s, v14.4s, v2.4s + + fmla v24.4s, v16.4s, v4.4s + fmla v25.4s, v17.4s, v5.4s + fmla v26.4s, v18.4s, v6.4s + + fmla v24.4s, v20.4s, v8.4s + fmla v25.4s, v21.4s, v9.4s + fmla v26.4s, v22.4s, v10.4s + + fadd v28.4s, v25.4s, v26.4s + fadd v28.4s, v28.4s, v24.4s + + cbz x3, ActivationRemain + BiasRemain: + fadd v28.4s, v28.4s, v31.4s + + ActivationRemain: + cbnz x7, Relu6Remain + cbnz x6, ReluRemain + b WriteRemain + Relu6Remain: + fmin v28.4s, v28.4s, v30.4s + ReluRemain: + movi v27.16b, #0 + fmax v28.4s, v28.4s, v27.4s + WriteRemain: + cmp x11, #4 + 
bge Write4Remain + cmp x11, #3 + beq Write3Remain + cmp x11, #2 + beq Write2Remain + cmp x11, #1 + beq Write1Remain + + Write1Remain: + str s28, [x12] + b LoopOwEnd + Write2Remain: + st1 {v28.2s}, [x12] + b LoopOwEnd + Write3Remain: + st1 {v28.2s}, [x12] + add x17, x12, #8 + st1 {v28.s}[2], [x17] + b LoopOwEnd + Write4Remain: + st1 {v28.4s}, [x12] + + LoopOwEnd: + subs x11, x11, #4 + add x0, x0, #16 + bgt LoopC4 + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Border.S new file mode 100644 index 00000000..5f4744dc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Border.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step, +// x8: kernel_w, x9: relu, x10: relu6 +asm_function ConvDwFp32Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + + ld1 {v0.4s}, [x3] // bias + movi v1.4s, #6 // relu 6 + scvtf v1.4s, v1.4s + dup v2.4s, wzr // relu + + mov x13, x1 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x5 + LoopW: + ld1 {v3.4s}, [x15], x7 + ld1 {v4.4s}, [x16], #16 + fmla v0.4s, v3.4s, v4.4s + subs x17, x17, #1 + bne LoopW + subs x4, x4, #1 + add x13, x13, x6 + add x14, x14, x8 + bne LoopH + cbnz x10, Relu6 + cbnz x9, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v1.4s + Relu: + fmax v0.4s, v0.4s, v2.4s + Write: + st1 {v0.4s}, [x0] + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Center.S new file mode 100644 index 00000000..568c1a33 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Center.S @@ -0,0 +1,313 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
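ConvDwFp32Border above computes a single border output pixel over the clipped kernel window. A scalar reference for one C4 channel block, hedged as a sketch: height and width here are the clipped kernel extents at the border pixel, and the step arguments are element strides (the assembly receives them pre-scaled to bytes):

#include <math.h>
#include <stddef.h>

static void conv_dw_fp32_border_ref(float *dst, const float *src,
                                    const float *weight, const float *bias,
                                    size_t height, size_t width,
                                    size_t in_kh_step, size_t in_kw_step,
                                    size_t kernel_w, int relu, int relu6) {
  for (int c = 0; c < 4; ++c) {
    float acc = bias[c];
    for (size_t kh = 0; kh < height; ++kh) {
      for (size_t kw = 0; kw < width; ++kw) {
        acc += src[kh * in_kh_step + kw * in_kw_step + c] *
               weight[(kh * kernel_w + kw) * 4 + c];  /* row stride skips clipped taps */
      }
    }
    if (relu6) acc = fminf(acc, 6.0f);
    if (relu || relu6) acc = fmaxf(acc, 0.0f);
    dst[c] = acc;
  }
}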
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +asm_function ConvDwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + + ld1 {v24.4s}, [x3] + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x25, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v8.4s, v16.4s, v25.4s + fmla v9.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v10.4s, v18.4s, v25.4s + fmla v11.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v12.4s, v20.4s, v25.4s + fmla v13.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v14.4s, v22.4s, v25.4s + fmla v15.4s, v23.4s, v25.4s + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz 
x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s + Relu16: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s + Write16: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + st1 {v8.4s}, [x3], x9 + st1 {v9.4s}, [x3], x9 + st1 {v10.4s}, [x3], x9 + st1 {v11.4s}, [x3], x9 + st1 {v12.4s}, [x3], x9 + st1 {v13.4s}, [x3], x9 + st1 {v14.4s}, [x3], x9 + st1 {v15.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x25, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + Relu8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + Write8: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v16.4s}, [x22], x13 + ld1 {v25.4s}, [x17], #16 + fmla v0.4s, v16.4s, v25.4s + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz 
x14, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v26.4s + Relu: + fmax v0.4s, v0.4s, v27.4s + Write: + st1 {v0.4s}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect3x3.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect3x3.S new file mode 100644 index 00000000..aafde321 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect3x3.S @@ -0,0 +1,159 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, +// size_t input_stride, size_t relu, size_t relu6) +// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6 + +asm_function ConvDwFp32Indirect3x3 + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + movi v31.4s, #6 + scvtf v31.4s, v31.4s + dup v30.4s, wzr + + ldr x8, [sp, #32] + cmp x5, #0 + beq End + + LoopPixel: + ldp x12, x13, [x1] + ldp x14, x15, [x1, #16] + ldp x16, x17, [x1, #32] + ldp x21, x19, [x1, #48] + ldr x20, [x1, #64] + mov x9, x2 + mov x10, x3 + mov x11, x4 + + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x13], #16 + ld1 {v2.4s}, [x14], #16 + + ld1 {v17.4s}, [x9], #16 + ld1 {v18.4s}, [x9], #16 + ld1 {v19.4s}, [x9], #16 + + ld1 {v29.4s}, [x10], #16 + cmp x11, #4 + ble LeftLoop + LoopC4: + fmla v29.4s, v0.4s, v17.4s + ld1 {v3.4s}, [x15], #16 + ld1 {v20.4s}, [x9], #16 + fmla v29.4s, v1.4s, v18.4s + ld1 {v4.4s}, [x16], #16 + ld1 {v21.4s}, [x9], #16 + fmla v29.4s, v2.4s, v19.4s + ld1 {v5.4s}, [x17], #16 + ld1 {v22.4s}, [x9], #16 + fmla v29.4s, v3.4s, v20.4s + ld1 {v6.4s}, [x21], #16 + ld1 {v23.4s}, [x9], #16 + fmla v29.4s, v4.4s, v21.4s + ld1 {v7.4s}, [x19], #16 + ld1 {v24.4s}, [x9], #16 + fmla v29.4s, v5.4s, v22.4s + ld1 {v16.4s}, [x20], #16 + ld1 {v25.4s}, [x9], #16 + fmla v29.4s, v6.4s, v23.4s + ld1 {v0.4s}, [x12], #16 + ld1 {v17.4s}, [x9], #16 + fmla v29.4s, v7.4s, v24.4s + ld1 {v1.4s}, [x13], #16 + ld1 {v18.4s}, [x9], #16 + fmla v29.4s, v16.4s, v25.4s + ld1 {v2.4s}, [x14], #16 + ld1 {v19.4s}, [x9], #16 + + cbnz x8, Relu6 + cbnz x7, Relu + b Write + Relu6: + fmin v29.4s, v29.4s, v31.4s + Relu: + fmax v29.4s, v29.4s, v30.4s + Write: + st1 {v29.4s}, [x0], #16 + + ld1 {v29.4s}, [x10], #16 + sub x11, x11, #4 + cmp x11, #4 + bgt LoopC4 + + LeftLoop: + fmla v29.4s, v0.4s, v17.4s + ld1 {v3.4s}, [x15], #16 + ld1 {v20.4s}, [x9], #16 + fmla v29.4s, v1.4s, v18.4s + ld1 {v4.4s}, 
[x16], #16 + ld1 {v21.4s}, [x9], #16 + fmla v29.4s, v2.4s, v19.4s + ld1 {v5.4s}, [x17], #16 + ld1 {v22.4s}, [x9], #16 + fmla v29.4s, v3.4s, v20.4s + ld1 {v6.4s}, [x21], #16 + ld1 {v23.4s}, [x9], #16 + fmla v29.4s, v4.4s, v21.4s + ld1 {v7.4s}, [x19], #16 + ld1 {v24.4s}, [x9], #16 + fmla v29.4s, v5.4s, v22.4s + ld1 {v16.4s}, [x20], #16 + ld1 {v25.4s}, [x9], #16 + fmla v29.4s, v6.4s, v23.4s + fmla v29.4s, v7.4s, v24.4s + fmla v29.4s, v16.4s, v25.4s + + cbnz x8, LeftRelu6 + cbnz x7, LeftRelu + b LeftWrite + LeftRelu6: + fmin v29.4s, v29.4s, v31.4s + LeftRelu: + fmax v29.4s, v29.4s, v30.4s + LeftWrite: + cmp x11, #4 + bne Write3 + st1 {v29.4s}, [x0], #16 + b NextPixel + Write3: + sxtw x11, w11 + tbnz w11, #1, Write2 + tbnz w11, #0, Write1 + Write2: + st1 {v29.2s}, [x0], #8 + ext v29.16b, v29.16b, v29.16b, #8 + tbz w11, #0, NextPixel + Write1: + str s29, [x0], #4 + + NextPixel: + add x1, x1, x6 + sub x5, x5, #1 + cmp x5, #0 + bgt LoopPixel +End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 +ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect5x5.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect5x5.S new file mode 100644 index 00000000..87f48ac3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Indirect5x5.S @@ -0,0 +1,304 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
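The float **input argument of ConvDwFp32Indirect3x3 above is an indirection buffer: for every output pixel the kernel receives nine row pointers, so padding and stride are resolved when the buffer is built and the inner loop needs no bounds checks; input_stride simply advances the pointer array between pixels. A scalar sketch; the weight layout is shown flat per tap for clarity, whereas the assembly tiles it in 4-channel groups:

#include <math.h>

static void dw_indirect3x3_pixel_ref(float *out, const float *const in[9],
                                     const float *w, const float *bias,
                                     int channels, int relu, int relu6) {
  for (int c = 0; c < channels; ++c) {
    float acc = bias[c];
    for (int k = 0; k < 9; ++k) {
      acc += in[k][c] * w[k * channels + c];
    }
    if (relu6) acc = fminf(acc, 6.0f);
    if (relu || relu6) acc = fmaxf(acc, 0.0f);
    out[c] = acc;
  }
}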
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, +// size_t input_stride, size_t relu, size_t relu6) +// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6 + +asm_function ConvDwFp32Indirect5x5 + sub sp, sp, #176 + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + stp x23, x24, [sp, #96] + stp x25, x26, [sp, #112] + stp x27, x28, [sp, #128] + stp x29, x30, [sp, #144] + ldrb w8, [sp, #176] + stp x2, x3, [sp] + stp x4, x6, [sp, #16] + stp x7, x8, [sp, #32] + stp x0, x1, [sp, #160] + + movi v31.4s, #6 + scvtf v31.4s, v31.4s + dup v30.4s, wzr + + mov x3, x5 + cmp x3, #0 + beq End + + LoopPixel: + ldp x5, x4, [sp] // weight, bias + ld1 {v29.4s}, [x4], #16 + ldr x2, [sp, #16] // channel + + ldp x6, x7, [x1] + ldp x8, x9, [x1, #16] + ldp x10, x11, [x1, #32] + ldp x12, x13, [x1, #48] + ldp x14, x15, [x1, #64] + ldp x16, x17, [x1, #80] + ldp x0, x19, [x1, #96] + ldp x20, x21, [x1, #112] + ldp x22, x23, [x1, #128] + ldp x24, x25, [x1, #144] + ldp x26, x27, [x1, #160] + ldp x28, x29, [x1, #176] + ldr x30, [x1, #192] + + ld1 {v0.4s}, [x6], #16 + ld1 {v1.4s}, [x7], #16 + ld1 {v2.4s}, [x8], #16 + ld1 {v3.4s}, [x9], #16 + ld1 {v4.4s}, [x10], #16 + + ld1 {v18.4s}, [x5], #16 + ld1 {v19.4s}, [x5], #16 + ld1 {v20.4s}, [x5], #16 + ld1 {v21.4s}, [x5], #16 + ld1 {v22.4s}, [x5], #16 + stp x5, x4, [sp, #48] + + cmp x2, #4 + ble LeftLoop + LoopC4: + ldr x5, [sp, #48] + // column 0 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x11], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x12], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x13], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x14], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x15], #16 + ld1 {v27.4s}, [x5], #16 + // column 1 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x16], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x17], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x0], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x19], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x20], #16 + ld1 {v22.4s}, [x5], #16 + // column 2 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x21], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x22], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x23], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x24], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x25], #16 + ld1 {v27.4s}, [x5], #16 + // column 3 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x26], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x27], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x28], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x29], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x30], #16 + ld1 {v22.4s}, [x5], #16 + // column 4 + fmla v29.4s, v0.4s, v18.4s + fmla v29.4s, v1.4s, v19.4s + ld1 {v0.4s}, [x6], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v1.4s}, [x7], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v2.4s}, [x8], #16 + ld1 
{v20.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v3.4s}, [x9], #16 + ld1 {v21.4s}, [x5], #16 + ld1 {v4.4s}, [x10], #16 + ld1 {v22.4s}, [x5], #16 + str x5, [sp, #48] + + ldp x4, x5, [sp, #32] + cbnz x5, RELU6 + cbnz x4, RELU + b WRITE + RELU6: + fmin v29.4s, v29.4s, v31.4s + RELU: + fmax v29.4s, v29.4s, v30.4s + WRITE: + ldr x4, [sp, #160] + st1 {v29.4s}, [x4], #16 + str x4, [sp, #160] + + ldr x4, [sp, #56] + ld1 {v29.4s}, [x4], #16 + str x4, [sp, #56] + sub x2, x2, #4 + cmp x2, #4 + bgt LoopC4 + + LeftLoop: + // column 0 + ldr x5, [sp, #48] + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x11], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x12], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x13], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x14], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x15], #16 + ld1 {v27.4s}, [x5], #16 + // column 1 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x16], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x17], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x0], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x19], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x20], #16 + ld1 {v22.4s}, [x5], #16 + // column 2 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x21], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x22], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x23], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x24], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x25], #16 + ld1 {v27.4s}, [x5], #16 + // column 3 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x26], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x27], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x28], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x29], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x30], #16 + ld1 {v22.4s}, [x5], #16 + // column 4 + fmla v29.4s, v0.4s, v18.4s + fmla v29.4s, v1.4s, v19.4s + fmla v29.4s, v2.4s, v20.4s + fmla v29.4s, v3.4s, v21.4s + fmla v29.4s, v4.4s, v22.4s + + ldp x4, x5, [sp, #32] + cbnz x5, LeftRelu6 + cbnz x4, LeftRelu + b LeftWrite + LeftRelu6: + fmin v29.4s, v29.4s, v31.4s + LeftRelu: + fmax v29.4s, v29.4s, v30.4s + LeftWrite: + cmp x2, #4 + bne Write3 + ldr x4, [sp, #160] + st1 {v29.4s}, [x4], #16 + str x4, [sp, #160] + b NextPixel + Write3: + sxtw x2, w2 + tbnz w2, #1, Write2 + tbnz w2, #0, Write1 + Write2: + ldr x4, [sp, #160] + st1 {v29.2s}, [x4], #8 + str x4, [sp, #160] + ext v29.16b, v29.16b, v29.16b, #8 + tbz w2, #0, NextPixel + Write1: + ldr x4, [sp, #160] + str s29, [x4], #4 + str x4, [sp, #160] + + NextPixel: + ldr x2, [sp, #24] + add x1, x1, x2 + sub x3, x3, #1 + cmp x3, #0 + bgt LoopPixel +End: + ldp x19, x20, [sp, #64] + ldp x21, x22, [sp, #80] + ldp x23, x24, [sp, #96] + ldp x25, x26, [sp, #112] + ldp x27, x28, [sp, #128] + ldp x29, x30, [sp, #144] + add sp, sp, #176 +ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Row.S new file mode 100644 index 00000000..59923da4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwFp32Row.S @@ 
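Both indirect kernels end each pixel with the Write3/Write2/Write1 tails seen above: the last (channels % 4) lanes of the result vector are stored using bit tests on the remainder, with the ext instruction rotating lane 2 down for the final single store. Equivalent scalar logic (illustrative name):

static void store_c4_tail(float *dst, const float v[4], int remainder /* 1..3 */) {
  int i = 0;
  if (remainder & 2) { dst[0] = v[0]; dst[1] = v[1]; i = 2; }  /* tbnz bit 1 */
  if (remainder & 1) { dst[i] = v[i]; }                        /* tbnz bit 0 */
}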
-0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Row(float* output_ptr, const float* input_ptr,const float* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels, +// x4: input_channel, x5: input_step +// +asm_function ConvDwFp32Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +ble End + +mov x9, x0 +mov x12, #4 +mul x5, x5, x12 + +LoopOutPixel: +mov x6, x1 +mov x7, x2 +mov x8, x4 + + LoopDepth16In: + cmp x8, #16 + blt L4 + sub x8, x8, #16 + + ld1 {v0.4s, v1.4s}, [x6], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + cmp x8, #16 + blt LoopDepth16Out + LoopDepth16: + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v3.4s + + st1 {v16.4s, v17.4s}, [x9], #32 + + ld1 {v4.4s, v5.4s}, [x6], #32 + ld1 {v6.4s, v7.4s}, [x7], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + + fmla v18.4s, v4.4s, v6.4s + fmla v19.4s, v5.4s, v7.4s + + st1 {v18.4s, v19.4s}, [x9], #32 + + ld1 {v0.4s, v1.4s}, [x6], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + sub x8, x8, #16 + cmp x8, #16 + bge LoopDepth16 + + LoopDepth16Out: + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v3.4s + st1 {v16.4s, v17.4s}, [x9], #32 + + ld1 {v4.4s, v5.4s}, [x6], #32 + ld1 {v6.4s, v7.4s}, [x7], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + + fmla v18.4s, v4.4s, v6.4s + fmla v19.4s, v5.4s, v7.4s + + st1 {v18.4s, v19.4s}, [x9], #32 + + L4: + cmp x8, #4 + blt L0 + + LoopDepth4: + ld1 {v0.4s}, [x6], #16 + ld1 {v2.4s}, [x7], #16 + ld1 {v16.4s}, [x0], #16 + fmla v16.4s, v0.4s, v2.4s + st1 {v16.4s}, [x9], #16 + sub x8, x8, #4 + cmp x8, #4 + bge LoopDepth4 + + L0: + cmp x8, #0 + beq Loop16LineEnd + + LoopDepth0: + ldr s0, [x6], #4 + ldr s1, [x7], #4 + ldr s2, [x0], #4 + fmul s0, s0, s1 + fadd s2, s2, s0 + str s2, [x9], #4 + subs x8, x8, #1 + bne LoopDepth0 + + Loop16LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Center.S new file mode 100644 index 00000000..2648795e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Center.S @@ -0,0 +1,294 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
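ConvDwFp32Row above is the simplest kernel in this family; the 16/4/1-lane loops are tilings of the following scalar reference. input_step counts elements here (the assembly scales it to bytes up front), the filter restarts at each pixel, and the output cursor advances continuously:

#include <stddef.h>

static void conv_dw_fp32_row_ref(float *output_ptr, const float *input_ptr,
                                 const float *filter_ptr, size_t num_pixels,
                                 size_t input_channel, size_t input_step) {
  for (size_t p = 0; p < num_pixels; ++p) {
    for (size_t c = 0; c < input_channel; ++c) {
      output_ptr[c] += input_ptr[c] * filter_ptr[c];  /* fmla accumulate in place */
    }
    output_ptr += input_channel;
    input_ptr += input_step;
  }
}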
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, +// size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, +// size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, +// int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, +// int32_t *acc_min, int32_t *acc_max) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: in_zp, #56: out_zp, #64: out_multiplier, #72:left_shift, #80: right_shift, #88: acc_min, #96: acc_max +asm_function ConvDwInt8Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + + ldr x14, [sp, #240] // input_zp + ld1 {v19.8b}, [x14], #8 + + ldr x15, [sp, #248] // output_zp + ld1 {v20.4s}, [x15], #16 + ld1 {v21.4s}, [x15], #16 + + ldr x16, [sp, #256] // out_multiplier + ld1 {v22.4s}, [x16], #16 + ld1 {v23.4s}, [x16], #16 + + ldr x17, [sp, #264] // left_shift + ld1 {v24.4s}, [x17], #16 + ld1 {v25.4s}, [x17], #16 + + ldr x25, [sp, #272] // right shift + ld1 {v26.4s}, [x25], #16 + ld1 {v27.4s}, [x25], #16 + + ldr x19, [sp, #280] // acc_min + ld1 {v28.4s}, [x19], #16 + ld1 {v29.4s}, [x19], #16 + + ldr x20, [sp, #288] // acc_max + ld1 {v30.4s}, [x20], #16 + ld1 {v31.4s}, [x20], #16 + + ld1 {v17.4s}, [x3], #16 + ld1 {v18.4s}, [x3], #16 + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + + LoopW4: + mov x19, #4 + mul x19, x19, x11 + mov x25, #4 + mul x25, x25, x9 + + mov x16, x23 + mov x17, x2 + mov x20, x6 + + mov v0.16b, v17.16b + mov v1.16b, v18.16b + mov v2.16b, v17.16b + mov v3.16b, v18.16b + mov v4.16b, v17.16b + mov v5.16b, v18.16b + mov v6.16b, v17.16b + mov v7.16b, v18.16b + LoopKh4: + mov x25, x7 + mov x21, x16 + LoopKw4: + mov x22, x21 + ld1 {v16.8h}, [x17], #16 + + ld1 {v15.8b}, [x22], x11 + ssubl v14.8h, v15.8b, v19.8b + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h + + ld1 {v13.8b}, [x22], x11 + ssubl v12.8h, v13.8b, v19.8b + smlal v2.4s, v12.4h, v16.4h + smlal2 v3.4s, v12.8h, v16.8h + + ld1 {v11.8b}, [x22], x11 + ssubl v10.8h, v11.8b, v19.8b + smlal v4.4s, v10.4h, v16.4h + smlal2 v5.4s, v10.8h, v16.8h + + ld1 {v9.8b}, [x22], x11 + ssubl v8.8h, v9.8b, v19.8b + smlal v6.4s, v8.4h, v16.4h + smlal2 v7.4s, v8.8h, 
v16.8h + + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw4 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh4 + + sqshl v0.4s, v0.4s, v24.4s + sqshl v1.4s, v1.4s, v25.4s + sqshl v2.4s, v2.4s, v24.4s + sqshl v3.4s, v3.4s, v25.4s + sqshl v4.4s, v4.4s, v24.4s + sqshl v5.4s, v5.4s, v25.4s + sqshl v6.4s, v6.4s, v24.4s + sqshl v7.4s, v7.4s, v25.4s + + sqrdmulh v0.4s, v0.4s, v22.4s + sqrdmulh v1.4s, v1.4s, v23.4s + sqrdmulh v2.4s, v2.4s, v22.4s + sqrdmulh v3.4s, v3.4s, v23.4s + sqrdmulh v4.4s, v4.4s, v22.4s + sqrdmulh v5.4s, v5.4s, v23.4s + sqrdmulh v6.4s, v6.4s, v22.4s + sqrdmulh v7.4s, v7.4s, v23.4s + + sqrshl v0.4s, v0.4s, v26.4s + sqrshl v1.4s, v1.4s, v27.4s + sqrshl v2.4s, v2.4s, v26.4s + sqrshl v3.4s, v3.4s, v27.4s + sqrshl v4.4s, v4.4s, v26.4s + sqrshl v5.4s, v5.4s, v27.4s + sqrshl v6.4s, v6.4s, v26.4s + sqrshl v7.4s, v7.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + add v1.4s, v1.4s, v21.4s + add v2.4s, v2.4s, v20.4s + add v3.4s, v3.4s, v21.4s + add v4.4s, v4.4s, v20.4s + add v5.4s, v5.4s, v21.4s + add v6.4s, v6.4s, v20.4s + add v7.4s, v7.4s, v21.4s + smax v0.4s, v0.4s, v28.4s + smax v1.4s, v1.4s, v29.4s + smax v2.4s, v2.4s, v28.4s + smax v3.4s, v3.4s, v29.4s + smax v4.4s, v4.4s, v28.4s + smax v5.4s, v5.4s, v29.4s + smax v6.4s, v6.4s, v28.4s + smax v7.4s, v7.4s, v29.4s + smin v0.4s, v0.4s, v30.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v30.4s + smin v3.4s, v3.4s, v31.4s + smin v4.4s, v4.4s, v30.4s + smin v5.4s, v5.4s, v31.4s + smin v6.4s, v6.4s, v30.4s + smin v7.4s, v7.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + sqxtn v4.4h, v4.4s + sqxtn v5.4h, v5.4s + sqxtn v6.4h, v6.4s + sqxtn v7.4h, v7.4s + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + sqxtn v4.8b, v4.8h + sqxtn v5.8b, v5.8h + sqxtn v6.8b, v6.8h + sqxtn v7.8b, v7.8h + + mov x16, x3 + add x17, x16, x9 + add x25, x17, x9 + add x21, x25, x9 + + st1 {v0.s}[0], [x16], #4 + st1 {v1.s}[0], [x16], #4 + st1 {v2.s}[0], [x17], #4 + st1 {v3.s}[0], [x17], #4 + st1 {v4.s}[0], [x25], #4 + st1 {v5.s}[0], [x25], #4 + st1 {v6.s}[0], [x21], #4 + st1 {v7.s}[0], [x21], #4 + + add x3, x3, x25 + add x23, x23, x19 + sub x24, x24, #4 + cmp x24, #0 + ble LoopWEnd + cmp x24, #4 + bge LoopW4 + + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v17.16b + mov v1.16b, v18.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v15.8b}, [x22], x13 + ssubl v14.8h, v15.8b, v19.8b + ld1 {v16.8h}, [x17], #16 + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + + sqshl v0.4s, v0.4s, v24.4s + sqrdmulh v0.4s, v0.4s, v22.4s + sqshl v1.4s, v1.4s, v25.4s + sqrdmulh v1.4s, v1.4s, v23.4s + + sqrshl v0.4s, v0.4s, v26.4s + sqrshl v1.4s, v1.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + smax v0.4s, v0.4s, v28.4s + smin v0.4s, v0.4s, v30.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + add v1.4s, v1.4s, v21.4s + smax v1.4s, v1.4s, v29.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v1.4h, v1.4s + sqxtn v1.8b, v1.8h + + mov x17, x3 + st1 {v0.s}[0], [x17], #4 + st1 {v1.s}[0], [x17], #4 + add x3, x3, x9 + + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git 
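In the int8 center kernel above, each block carries eight channels as two int32 quadword accumulators, and the low and high halves use separate requant parameter vectors (v24/v25, v22/v23, v26/v27, and so on). Unlike the per-layer 3x3 path, left shift, multiply, and rounding right shift are all applied unconditionally. A sketch reusing the illustrative sqrdmulh32/rounding_rshift helpers from the earlier block; layout and strides remain schematic:

#include <stdint.h>

static void dw_int8_block8_ref(int8_t *dst, const int32_t acc[8],
                               const int32_t mult[8], const int32_t ls[8],
                               const int32_t rs[8], const int32_t out_zp[8],
                               const int32_t acc_min[8], const int32_t acc_max[8]) {
  for (int c = 0; c < 8; ++c) {
    int32_t v = (int32_t)((int64_t)acc[c] << ls[c]);  /* sqshl */
    v = sqrdmulh32(v, mult[c]);                       /* sqrdmulh */
    v = rounding_rshift(v, rs[c]);                    /* sqrshl */
    v += out_zp[c];
    v = v < acc_min[c] ? acc_min[c] : v;
    v = v > acc_max[c] ? acc_max[c] : v;
    dst[c] = (int8_t)v;
  }
}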
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4.S new file mode 100644 index 00000000..8d678174 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4.S @@ -0,0 +1,191 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, +// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier, +// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max + +asm_function ConvDwInt8PostAlign4 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + dup v26.4s, w5 + dup v27.4s, w4 + dup v28.4s, w6 + + dup v29.4s, w3 + dup v30.4s, w7 + dup v31.4s, w8 + + cmp x2, #16 + blt LoopDepth8 + + LoopDepth16: + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + ld1 {v2.4s}, [x1], #16 + ld1 {v3.4s}, [x1], #16 + + cbz w5, RightShiftDepth16 + sqshl v0.4s, v0.4s, v26.4s + sqshl v1.4s, v1.4s, v26.4s + sqshl v2.4s, v2.4s, v26.4s + sqshl v3.4s, v3.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + sqrdmulh v2.4s, v2.4s, v27.4s + sqrdmulh v3.4s, v3.4s, v27.4s + b AddZpDepth16 + + RightShiftDepth16: + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + sqrdmulh v2.4s, v2.4s, v27.4s + sqrdmulh v3.4s, v3.4s, v27.4s + + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + and v5.16b, v1.16b, v28.16b + sshr v5.4s, v5.4s, #31 + sqadd v1.4s, v1.4s, v5.4s + srshl v1.4s, v1.4s, v28.4s + and v6.16b, v2.16b, v28.16b + sshr v6.4s, v6.4s, #31 + sqadd v2.4s, v2.4s, v6.4s + srshl v2.4s, v2.4s, v28.4s + and v7.16b, v3.16b, v28.16b + sshr v7.4s, v7.4s, #31 + sqadd v3.4s, v3.4s, v7.4s + srshl v3.4s, v3.4s, v28.4s + + AddZpDepth16: + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + add v2.4s, v2.4s, v29.4s + add v3.4s, v3.4s, v29.4s + + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smax v2.4s, v2.4s, v30.4s + smax v3.4s, v3.4s, v30.4s + + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v31.4s + smin v3.4s, v3.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + st1 {v2.s}[0], [x0], #4 + st1 {v3.s}[0], [x0], #4 + + sub x2, x2, #16 + cmp x2, 
#16 + bge LoopDepth16 + + LoopDepth8: + cmp x2, #8 + blt LoopDepth4 + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + + cbz w5, RightShiftDepth8 + sqshl v0.4s, v0.4s, v26.4s + sqshl v1.4s, v1.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + b AddZpDepth8 + + RightShiftDepth8: + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + and v5.16b, v1.16b, v28.16b + sshr v5.4s, v5.4s, #31 + sqadd v1.4s, v1.4s, v5.4s + srshl v1.4s, v1.4s, v28.4s + + AddZpDepth8: + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + + sub x2, x2, #8 + cmp x2, #8 + bge LoopDepth8 + + LoopDepth4: + cmp x2, #4 + blt End + ld1 {v0.4s}, [x1], #16 + + sqshl v0.4s, v0.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], #4 + + sub x2, x2, #4 + bge LoopDepth4 + End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 00000000..7f14f7c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
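The and/sshr/sqadd/srshl group in the RightShiftDepth* paths above is the standard NEON idiom (gemmlowp-style) for a rounding divide by a power of two with ties rounded away from zero: the and works because the shift vector holds the negative shift count, so its sign bit is set and the fixup subtracts one from negative accumulators before the rounding shift. A scalar equivalent, with exponent = -right_shift:

#include <stdint.h>

static int32_t rounding_divide_by_pot(int32_t x, int exponent) {
  const int32_t mask = (int32_t)((1u << exponent) - 1u);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}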
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier, +// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max + +asm_function ConvDwInt8PostAlign4PerChannel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + dup v29.4s, w3 + dup v30.4s, w7 + dup v31.4s, w8 + + LoopDepth8: + cmp x2, #8 + blt LoopDepth4 + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + + ld1 {v2.4s}, [x5], #16 + ld1 {v3.4s}, [x5], #16 + + ld1 {v4.4s}, [x4], #16 + ld1 {v5.4s}, [x4], #16 + + sqshl v0.4s, v0.4s, v2.4s + sqshl v1.4s, v1.4s, v3.4s + + ld1 {v6.4s}, [x6], #16 + ld1 {v7.4s}, [x6], #16 + + sqrdmulh v0.4s, v0.4s, v4.4s + sqrdmulh v1.4s, v1.4s, v5.4s + and v16.16b, v0.16b, v6.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + and v17.16b, v1.16b, v7.16b + sshr v17.4s, v17.4s, #31 + sqadd v1.4s, v1.4s, v17.4s + srshl v1.4s, v1.4s, v7.4s + + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + + sub x2, x2, #8 + cmp x2, #8 + bge LoopDepth8 + + LoopDepth4: + cmp x2, #4 + blt End + ld1 {v0.4s}, [x1], #16 + ld1 {v2.4s}, [x5], #16 + + sqshl v0.4s, v0.4s, v2.4s + + ld1 {v4.4s}, [x4], #16 + sqrdmulh v0.4s, v0.4s, v4.4s + + ld1 {v6.4s}, [x6], #16 + and v16.16b, v0.16b, v6.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], #4 + + sub x2, x2, #4 + bge LoopDepth4 + End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Row.S new file mode 100644 index 00000000..5828d575 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvDwInt8Row.S @@ -0,0 +1,134 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, +// int output_channel, int input_step, int8_t input_zp) +// x0: output_ptr, x1: input_ptr, x2: weight_ptr, x3: num_pixels, +// x4: output_channel, x5: input_step, x6: input_zp +// +asm_function ConvDwInt8Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +beq End + +mov x10, x0 + +dup v31.8b, w6 + +LoopOutPixel: +mov x7, x1 +mov x8, x2 +mov x9, x4 + + LoopDepth16In: + cmp x9, #16 + blt L8 + sub x9, x9, #16 + + ld1 {v0.8b, v1.8b}, [x7], #16 + ld1 {v2.8h, v3.8h}, [x8], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + + + cmp x9, #16 + blt LoopDepth16Out + LoopDepth16: + + st1 {v16.4s, v17.4s}, [x10], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + ssubl v21.8h, v1.8b, v31.8b + smlal v18.4s, v21.4h, v3.4h + smlal2 v19.4s, v21.8h, v3.8h + st1 {v18.4s, v19.4s}, [x10], #32 + + ld1 {v0.8b, v1.8b}, [x7], #16 + ld1 {v2.8h, v3.8h}, [x8], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + + sub x9, x9, #16 + cmp x9, #16 + bge LoopDepth16 + + LoopDepth16Out: + + st1 {v16.4s, v17.4s}, [x10], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + ssubl v21.8h, v1.8b, v31.8b + smlal v18.4s, v21.4h, v3.4h + smlal2 v19.4s, v21.8h, v3.8h + st1 {v18.4s, v19.4s}, [x10], #32 + + L8: + cmp x9, #8 + blt L0 + + LoopDepth8: + ld1 {v0.8b}, [x7], #8 + ld1 {v2.8h}, [x8], #16 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + st1 {v16.4s, v17.4s}, [x10], #32 + + sub x9, x9, #8 + cmp x9, #8 + bge LoopDepth8 + + L0: + cmp x9, #0 + beq Loop16LineEnd + + LoopDepth0: + ldrsb w14, [x7], #1 + ldrsh w15, [x8], #2 + ldr w16, [x0], #4 + sub w14, w14, w6 + + sxth w14, w14 + madd w14, w14, w15, w16 + str w14, [x10], #4 + + subs x9, x9, #1 + bne LoopDepth0 + + Loop16LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvFp32Center.S new file mode 100644 index 00000000..9cead57c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvFp32Center.S @@ -0,0 +1,458 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
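ConvDwInt8Row above is the accumulation half of the depthwise int8 pipeline: ssubl subtracts the input zero point while widening the int8 inputs to int16, and smlal/smlal2 accumulate the int16 products into the int32 buffer that the PostAlign kernels then requantize. A scalar equivalent, with strides in elements and an illustrative Ref name:

#include <stdint.h>

void ConvDwInt8RowRef(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr,
                      int num_pixels, int output_channel, int input_step, int8_t input_zp) {
  for (int p = 0; p < num_pixels; ++p) {
    for (int c = 0; c < output_channel; ++c) {
      /* ssubl: (input - zp) widened to int16; smlal: 32-bit accumulate */
      output_ptr[c] += ((int16_t)input_ptr[c] - (int16_t)input_zp) * weight_ptr[c];
    }
    output_ptr += output_channel;  /* the store pointer x10 keeps advancing */
    input_ptr += input_step;       /* x1 += x5; weights rewind each pixel (x8 = x2) */
  }
}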
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: ic4, x11: in_sh_step, x12: in_sw_step, x13: in_kh_step, x14: in_kw_step +// x26: relu, x16: relu6 +asm_function ConvSwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x8, [sp, #208] + ldr x9, [sp, #216] + ldr x10, [sp, #224] + ldr x11, [sp, #232] + ldr x12, [sp, #240] + ldr x13, [sp, #248] + ldr x14, [sp, #256] + mul x15, x6, x7 + mul x15, x10, x15 + mov x16, #16 + mul x15, x15, x16 + + ld1 {v25.4s}, [x3] + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + LoopH: + mov x17, x1 + mov x28, x5 + mov x3, x0 + cmp x28, #8 + blt LoopW + cmp x28, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x12 + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + mov v1.16b, v25.16b + mov v2.16b, v25.16b + mov v3.16b, v25.16b + mov v4.16b, v25.16b + mov v5.16b, v25.16b + mov v6.16b, v25.16b + mov v7.16b, v25.16b + mov v8.16b, v25.16b + mov v9.16b, v25.16b + mov v10.16b, v25.16b + mov v11.16b, v25.16b + mov v12.16b, v25.16b + mov v13.16b, v25.16b + mov v14.16b, v25.16b + mov v15.16b, v25.16b + LoopKh16: + mov x23, x7 + mov x24, x20 + LoopKw16: + mov x25, x24 + mov x27, x10 + LoopIc16: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v28.4s, v17.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v1.4s, v30.4s, v17.s[2] + fmla v0.4s, v31.4s, v16.s[3] + fmla v1.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v2.4s, v28.4s, v18.s[0] + fmla v3.4s, v28.4s, v19.s[0] + fmla v2.4s, v29.4s, v18.s[1] + fmla v3.4s, v29.4s, v19.s[1] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v30.4s, v19.s[2] + fmla v2.4s, v31.4s, v18.s[3] + fmla v3.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v28.4s, v21.s[0] + fmla v4.4s, v29.4s, v20.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v4.4s, v30.4s, v20.s[2] + fmla v5.4s, v30.4s, v21.s[2] + fmla v4.4s, v31.4s, v20.s[3] + fmla v5.4s, v31.4s, v21.s[3] + ld1 
{v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + fmla v6.4s, v28.4s, v22.s[0] + fmla v7.4s, v28.4s, v23.s[0] + fmla v6.4s, v29.4s, v22.s[1] + fmla v7.4s, v29.4s, v23.s[1] + fmla v6.4s, v30.4s, v22.s[2] + fmla v7.4s, v30.4s, v23.s[2] + fmla v6.4s, v31.4s, v22.s[3] + fmla v7.4s, v31.4s, v23.s[3] + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v8.4s, v28.4s, v16.s[0] + fmla v9.4s, v28.4s, v17.s[0] + fmla v8.4s, v29.4s, v16.s[1] + fmla v9.4s, v29.4s, v17.s[1] + fmla v8.4s, v30.4s, v16.s[2] + fmla v9.4s, v30.4s, v17.s[2] + fmla v8.4s, v31.4s, v16.s[3] + fmla v9.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v10.4s, v28.4s, v18.s[0] + fmla v11.4s, v28.4s, v19.s[0] + fmla v10.4s, v29.4s, v18.s[1] + fmla v11.4s, v29.4s, v19.s[1] + fmla v10.4s, v30.4s, v18.s[2] + fmla v11.4s, v30.4s, v19.s[2] + fmla v10.4s, v31.4s, v18.s[3] + fmla v11.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v12.4s, v28.4s, v20.s[0] + fmla v13.4s, v28.4s, v21.s[0] + fmla v12.4s, v29.4s, v20.s[1] + fmla v13.4s, v29.4s, v21.s[1] + fmla v12.4s, v30.4s, v20.s[2] + fmla v13.4s, v30.4s, v21.s[2] + fmla v12.4s, v31.4s, v20.s[3] + fmla v13.4s, v31.4s, v21.s[3] + fmla v14.4s, v28.4s, v22.s[0] + fmla v15.4s, v28.4s, v23.s[0] + fmla v14.4s, v29.4s, v22.s[1] + fmla v15.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v22.s[2] + fmla v15.4s, v30.4s, v23.s[2] + fmla v14.4s, v31.4s, v22.s[3] + fmla v15.4s, v31.4s, v23.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc16 + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw16 + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh16 + ldr x16, [sp, #272] + cbnz x16, Relu616 + ldr x26, [sp, #264] + cbnz x26, Relu16 + b Write16 + Relu616: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s + Relu16: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s + Write16: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + st1 {v8.4s}, [x3], x9 + st1 {v9.4s}, [x3], x9 + st1 {v10.4s}, [x3], x9 + st1 {v11.4s}, [x3], x9 + st1 {v12.4s}, [x3], x9 + st1 {v13.4s}, [x3], x9 + st1 {v14.4s}, [x3], x9 + st1 {v15.4s}, [x3], x9 + add x17, x17, x19 + sub x28, x28, #16 + cmp x28, #0 + ble LoopWEnd + cmp x28, #8 + blt LoopW + cmp x28, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x12 + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + mov v1.16b, v25.16b + mov v2.16b, v25.16b + mov v3.16b, v25.16b + mov v4.16b, v25.16b + mov v5.16b, v25.16b + mov v6.16b, v25.16b + mov v7.16b, v25.16b + LoopKh8: + mov x23, x7 + mov x24, x20 + LoopKw8: + mov x25, x24 + mov 
x27, x10 + LoopIc8: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v28.4s, v17.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v1.4s, v30.4s, v17.s[2] + fmla v0.4s, v31.4s, v16.s[3] + fmla v1.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v2.4s, v28.4s, v18.s[0] + fmla v3.4s, v28.4s, v19.s[0] + fmla v2.4s, v29.4s, v18.s[1] + fmla v3.4s, v29.4s, v19.s[1] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v30.4s, v19.s[2] + fmla v2.4s, v31.4s, v18.s[3] + fmla v3.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v28.4s, v21.s[0] + fmla v4.4s, v29.4s, v20.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v4.4s, v30.4s, v20.s[2] + fmla v5.4s, v30.4s, v21.s[2] + fmla v4.4s, v31.4s, v20.s[3] + fmla v5.4s, v31.4s, v21.s[3] + fmla v6.4s, v28.4s, v22.s[0] + fmla v7.4s, v28.4s, v23.s[0] + fmla v6.4s, v29.4s, v22.s[1] + fmla v7.4s, v29.4s, v23.s[1] + fmla v6.4s, v30.4s, v22.s[2] + fmla v7.4s, v30.4s, v23.s[2] + fmla v6.4s, v31.4s, v22.s[3] + fmla v7.4s, v31.4s, v23.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc8 + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw8 + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh8 + ldr x16, [sp, #272] + cbnz x16, Relu68 + ldr x26, [sp, #264] + cbnz x26, Relu8 + b Write8 + Relu68: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + Relu8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + Write8: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + add x17, x17, x19 + sub x28, x28, #8 + cmp x28, #0 + ble LoopWEnd + cmp x28, #8 + bge LoopW8 + LoopW: + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + LoopKh: + mov x23, x7 + mov x24, x20 + LoopKw: + mov x25, x24 + mov x27, x10 + LoopIc: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + fmla v0.4s, v28.4s, v16.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v0.4s, v31.4s, v16.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh + ldr x16, [sp, #272] + cbnz x16, Relu6 
+ ldr x26, [sp, #264] + cbnz x26, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v26.4s + Relu: + fmax v0.4s, v0.4s, v27.4s + Write: + st1 {v0.4s}, [x3], x9 + add x17, x17, x12 + subs x28, x28, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x11 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x16Kernel.S new file mode 100644 index 00000000..7820c521 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x16Kernel.S @@ -0,0 +1,421 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv1x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + prfm pldl1keep, [x23] + mov x24, x23 + mov x25, x10 + subs x25, x25, #16 + blt LoopC12 + LoopC16: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x24], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + 
fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v2.4s, v18.4s, v6.s[0] + fmla v3.4s, v19.4s, v6.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[1] + fmla v1.4s, v17.4s, v6.s[1] + fmla v2.4s, v18.4s, v6.s[1] + fmla v3.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v2.4s, v18.4s, v6.s[2] + fmla v3.4s, v19.4s, v6.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[3] + fmla v1.4s, v17.4s, v6.s[3] + fmla v2.4s, v18.4s, v6.s[3] + fmla v3.4s, v19.4s, v6.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[0] + fmla v1.4s, v17.4s, v7.s[0] + fmla v2.4s, v18.4s, v7.s[0] + fmla v3.4s, v19.4s, v7.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[1] + fmla v1.4s, v17.4s, v7.s[1] + fmla v2.4s, v18.4s, v7.s[1] + fmla v3.4s, v19.4s, v7.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[2] + fmla v1.4s, v17.4s, v7.s[2] + fmla v2.4s, v18.4s, v7.s[2] + fmla v3.4s, v19.4s, v7.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[3] + fmla v1.4s, v17.4s, v7.s[3] + fmla v2.4s, v18.4s, v7.s[3] + fmla v3.4s, v19.4s, v7.s[3] + subs x25, x25, #16 + bge LoopC16 + LoopC12: + adds x25, x25, #16 + cbz x25, LoopCEnd + cmp x25, #12 + blt LoopC8 + ld1 {v4.4s, v5.4s, v6.4s}, [x24], #48 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, 
[x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v2.4s, v18.4s, v6.s[0] + fmla v3.4s, v19.4s, v6.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[1] + fmla v1.4s, v17.4s, v6.s[1] + fmla v2.4s, v18.4s, v6.s[1] + fmla v3.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v2.4s, v18.4s, v6.s[2] + fmla v3.4s, v19.4s, v6.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[3] + fmla v1.4s, v17.4s, v6.s[3] + fmla v2.4s, v18.4s, v6.s[3] + fmla v3.4s, v19.4s, v6.s[3] + sub x25, x25, #12 + b LoopCTail + LoopC8: + cmp x25, #8 + blt LoopC4 + ld1 {v4.4s, v5.4s}, [x24], #32 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v4.4s}, [x24], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + cmp x25, #2 + beq LoopC2 + cmp x25, #1 + beq LoopC1 + // LoopC3 + ld3r {v4.4s, v5.4s, v6.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v2.4s, v18.4s, v4.4s + fmla v3.4s, 
v19.4s, v4.4s + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.4s + fmla v1.4s, v17.4s, v5.4s + fmla v2.4s, v18.4s, v5.4s + fmla v3.4s, v19.4s, v5.4s + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.4s + fmla v1.4s, v17.4s, v6.4s + fmla v2.4s, v18.4s, v6.4s + fmla v3.4s, v19.4s, v6.4s + b LoopCEnd + LoopC2: + ld1 {v4.d}[0], [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + b LoopCEnd + LoopC1: + ld1r {v4.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v2.4s, v18.4s, v4.4s + fmla v3.4s, v19.4s, v4.4s + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v4.2d, xzr // relu + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v4.4s + fmax v2.4s, v2.4s, v4.4s + fmax v3.4s, v3.4s, v4.4s + + ands x6, x6, #1 + beq WriteBack + movi v4.4s, #6 // relu6 + scvtf v4.4s, v4.4s + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v4.4s + fmin v2.4s, v2.4s, v4.4s + fmin v3.4s, v3.4s, v4.4s + fmin v4.4s, v4.4s, v4.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + b End + NC4HW4: + add x21, x0, x7, LSL #1 + add x22, x20, x7, LSL #1 + st1 {v0.4s}, [x0] + st1 {v1.4s}, [x20] + st1 {v2.4s}, [x21] + st1 {v3.4s}, [x22] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x8Kernel.S new file mode 100644 index 00000000..7ed045f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW1x8Kernel.S @@ -0,0 +1,278 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
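SWConv1x16Kernel above produces one output pixel with 16 output channels of the sliding-window convolution: v0..v3 are seeded from the bias (or zeroed), and for every kernel tap and every input channel one packed row of 16 weights is scaled by a single input scalar; the LoopC16/C12/C8/C4 and tail branches are just different unrollings of that channel loop. A scalar model under simplifying assumptions: steps are taken in float elements (the assembly converts them to byte strides with lsl #2), and the out_step/write_mode NC4HW4 handling is omitted.

#include <stddef.h>

void SWConv1x16Ref(float *dst, const float *src, const float *weight, const float *bias,
                   size_t kernel_h, size_t kernel_w, size_t act_flag, size_t ic_algin,
                   size_t in_kw_step, size_t in_kh_step, size_t kw_remainder) {
  float acc[16];
  for (int oc = 0; oc < 16; ++oc) acc[oc] = (bias != NULL) ? bias[oc] : 0.0f;
  for (size_t kh = 0; kh < kernel_h; ++kh) {
    for (size_t kw = 0; kw < kernel_w; ++kw) {
      const float *in = src + kh * in_kh_step + kw * in_kw_step;
      for (size_t ic = 0; ic < ic_algin; ++ic) {
        for (int oc = 0; oc < 16; ++oc) {
          acc[oc] += weight[ic * 16 + oc] * in[ic];  /* four fmla per input scalar */
        }
      }
      weight += ic_algin * 16;
    }
    weight += kw_remainder;  /* skip packed padding between kernel rows */
  }
  for (int oc = 0; oc < 16; ++oc) {
    float v = acc[oc];
    if (act_flag & 0x3) v = (v < 0.0f) ? 0.0f : v;  /* relu  */
    if (act_flag & 0x1) v = (v > 6.0f) ? 6.0f : v;  /* relu6 */
    dst[oc] = v;
  }
}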
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv1x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + prfm pldl1keep, [x23] + mov x24, x23 + mov x25, x10 + subs x25, x25, #16 + blt LoopC12 + LoopC16: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x24], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v0.4s, v18.4s, v6.s[1] + fmla v1.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v0.4s, v18.4s, v6.s[3] + fmla v1.4s, v19.4s, v6.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[0] + fmla v1.4s, v17.4s, v7.s[0] + fmla v0.4s, v18.4s, v7.s[1] + fmla v1.4s, v19.4s, v7.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[2] + fmla v1.4s, v17.4s, v7.s[2] + fmla v0.4s, v18.4s, v7.s[3] + fmla v1.4s, v19.4s, v7.s[3] + subs x25, x25, #16 + bge LoopC16 + LoopC12: + adds x25, x25, #16 + cbz x25, LoopCEnd + cmp x25, #12 + blt LoopC8 + ld1 {v4.4s, v5.4s, v6.4s}, [x24], #48 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla 
v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v0.4s, v18.4s, v6.s[1] + fmla v1.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v0.4s, v18.4s, v6.s[3] + fmla v1.4s, v19.4s, v6.s[3] + sub x25, x25, #12 + b LoopCTail + LoopC8: + cmp x25, #8 + blt LoopC4 + ld1 {v4.4s, v5.4s}, [x24], #32 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v4.4s}, [x24], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + cmp x25, #2 + beq LoopC2 + cmp x25, #1 + beq LoopC1 + // LoopC3 + ld3r {v4.4s, v5.4s, v6.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + ld1 {v20.4s, v21.4s}, [x2], #32 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v0.4s, v18.4s, v5.4s + fmla v1.4s, v19.4s, v5.4s + fmla v0.4s, v20.4s, v6.4s + fmla v1.4s, v21.4s, v6.4s + b LoopCEnd + LoopC2: + ld1 {v4.2s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + b LoopCEnd + LoopC1: + ld1r {v4.4s}, [x24] + ld1 {v16.4s, v17.4s}, [x2], #32 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v4.2d, xzr // relu + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v4.4s + fmax v2.4s, v2.4s, v4.4s + fmax v3.4s, v3.4s, v4.4s + + ands x6, x6, #1 + beq WriteBack + movi v4.4s, #6 // relu6 + scvtf v4.4s, v4.4s + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v4.4s + fmin v2.4s, v2.4s, v4.4s + fmin v3.4s, v3.4s, v4.4s + fmin v4.4s, v4.4s, v4.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + b End + NC4HW4: + st1 {v0.4s}, [x0] + st1 {v1.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x16Kernel.S new file mode 100644 
index 00000000..221ebcf0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x16Kernel.S @@ -0,0 +1,407 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv2x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv2x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla 
v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, 
v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, 
x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + b End + NC4HW4: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x8Kernel.S new file mode 100644 index 00000000..0d3be107 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW2x8Kernel.S @@ -0,0 +1,265 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
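The two-row kernels above and below (SWConv2x16Kernel, then SWConv2x8Kernel) widen the output tile instead of the channel block: each packed weight load (v28..v31) is reused for two horizontally adjacent output pixels whose inputs sit in_sw_step apart, roughly halving weight-load traffic per output. A scalar sketch of the inner loop only, under the same element-stride assumptions as the earlier sketch; acc0/acc1 correspond to the v0..v3 and v4..v7 accumulator tiles.

#include <stddef.h>

static void SWConv2x16InnerRef(float acc0[16], float acc1[16], const float *in0,
                               const float *in1, const float *weight, size_t ic_algin) {
  /* in1 = in0 + in_sw_step: the second output pixel of the tile */
  for (size_t ic = 0; ic < ic_algin; ++ic) {
    float a0 = in0[ic];
    float a1 = in1[ic];
    for (int oc = 0; oc < 16; ++oc) {
      float w = weight[ic * 16 + oc];  /* loaded once, used twice */
      acc0[oc] += w * a0;              /* fmla into v0..v3 */
      acc1[oc] += w * a1;              /* fmla into v4..v7 */
    }
  }
}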
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv2x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv2x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v0.4s, v30.4s, v18.s[1] + fmla v1.4s, v31.4s, v18.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v0.4s, v30.4s, v18.s[3] + fmla v1.4s, v31.4s, v18.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + subs x25, x25, #12 
+ bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x16Kernel.S new file mode 100644 index 00000000..34706ef9 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x16Kernel.S @@ -0,0 +1,533 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv3x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv3x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + stp x23, x24, [sp, #96] + stp x25, x26, [sp, #112] + + ldr x10, [sp, #128] + ldr x11, [sp, #136] + ldr x12, [sp, #144] + ldr x13, [sp, #152] + ldr x14, [sp, #160] + ldr x15, [sp, #168] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x19, x23, x13, lsl #1 + prfm pldl1keep, [x19] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x19], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, 
v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v10.4s, v30.4s, v24.s[0] + fmla v11.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v24.s[1] + fmla v9.4s, v29.4s, v24.s[1] + fmla v10.4s, v30.4s, v24.s[1] + fmla v11.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, 
v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v10.4s, v30.4s, v24.s[2] + fmla v11.4s, v31.4s, v24.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v24.s[3] + fmla v9.4s, v29.4s, v24.s[3] + fmla v10.4s, v30.4s, v24.s[3] + fmla v11.4s, v31.4s, v24.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x19], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, 
v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x19], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x19], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // 
relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x8Kernel.S new file mode 100644 index 00000000..afbdecf5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW3x8Kernel.S @@ -0,0 +1,332 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void SWConv3x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
+//                      size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step,
+//                      size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode)
+// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step,
+// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode
+asm_function SWConv3x8Kernel
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved
+    // whereas our coding style does not permit such an amount of parameters
+    sub sp, sp, #64
+    stp x19, x20, [sp]
+    stp x21, x22, [sp, #16]
+    stp x23, x24, [sp, #32]
+    stp x25, x26, [sp, #48]
+
+    ldr x10, [sp, #64]
+    ldr x11, [sp, #72]
+    ldr x12, [sp, #80]
+    ldr x13, [sp, #88]
+    ldr x14, [sp, #96]
+    ldr x15, [sp, #104]
+    lsl x7, x7, #2
+    lsl x11, x11, #2
+    lsl x12, x12, #2
+    lsl x13, x13, #2
+    lsl x14, x14, #2
+    add x20, x0, x7
+
+    cbz x3, InitNoBias
+    InitWithBias:
+    ld1 {v0.4s, v1.4s}, [x3]
+    ld1 {v2.4s, v3.4s}, [x3]
+    ld1 {v4.4s, v5.4s}, [x3]
+    b LoopH
+    InitNoBias:
+    dup v0.2d, xzr
+    dup v1.2d, xzr
+    dup v2.2d, xzr
+    dup v3.2d, xzr
+    dup v4.2d, xzr
+    dup v5.2d, xzr
+
+    LoopH:
+    mov x22, x5
+    mov x23, x1
+    LoopW:
+    mov x25, x10
+    prfm pldl1keep, [x23]
+    mov x24, x23
+    add x26, x23, x13
+    prfm pldl1keep, [x26]
+    add x19, x23, x13, lsl #1
+    prfm pldl1keep, [x19]
+    subs x25, x25, #12
+    blt LoopC8
+    LoopC12:
+    ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48
+    ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48
+    ld1 {v22.4s, v23.4s, v24.4s}, [x19], #48
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    fmla v0.4s, v30.4s, v16.s[1]
+    fmla v1.4s, v31.4s, v16.s[1]
+    fmla v2.4s, v30.4s, v19.s[1]
+    fmla v3.4s, v31.4s, v19.s[1]
+    fmla v4.4s, v30.4s, v22.s[1]
+    fmla v5.4s, v31.4s, v22.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[2]
+    fmla v1.4s, v29.4s, v16.s[2]
+    fmla v2.4s, v28.4s, v19.s[2]
+    fmla v3.4s, v29.4s, v19.s[2]
+    fmla v4.4s, v28.4s, v22.s[2]
+    fmla v5.4s, v29.4s, v22.s[2]
+    fmla v0.4s, v30.4s, v16.s[3]
+    fmla v1.4s, v31.4s, v16.s[3]
+    fmla v2.4s, v30.4s, v19.s[3]
+    fmla v3.4s, v31.4s, v19.s[3]
+    fmla v4.4s, v30.4s, v22.s[3]
+    fmla v5.4s, v31.4s, v22.s[3]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[0]
+    fmla v1.4s, v29.4s, v17.s[0]
+    fmla v2.4s, v28.4s, v20.s[0]
+    fmla v3.4s, v29.4s, v20.s[0]
+    fmla v4.4s, v28.4s, v23.s[0]
+    fmla v5.4s, v29.4s, v23.s[0]
+    fmla v0.4s, v30.4s, v17.s[1]
+    fmla v1.4s, v31.4s, v17.s[1]
+    fmla v2.4s, v30.4s, v20.s[1]
+    fmla v3.4s, v31.4s, v20.s[1]
+    fmla v4.4s, v30.4s, v23.s[1]
+    fmla v5.4s, v31.4s, v23.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[2]
+    fmla v1.4s, v29.4s, v17.s[2]
+    fmla v2.4s, v28.4s, v20.s[2]
+    fmla v3.4s, v29.4s, v20.s[2]
+    fmla v4.4s, v28.4s, v23.s[2]
+    fmla v5.4s, v29.4s, v23.s[2]
+    fmla v0.4s, v30.4s, v17.s[3]
+    fmla v1.4s, v31.4s, v17.s[3]
+    fmla v2.4s, v30.4s, v20.s[3]
+    fmla v3.4s, v31.4s, v20.s[3]
+    fmla v4.4s, v30.4s, v23.s[3]
+    fmla v5.4s, v31.4s, v23.s[3]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v18.s[0]
+    fmla v1.4s, v29.4s, v18.s[0]
+    fmla v2.4s, v28.4s, v21.s[0]
+    fmla v3.4s, v29.4s, v21.s[0]
+    fmla v4.4s, v28.4s, v24.s[0]
+    fmla v5.4s, v29.4s, v24.s[0]
+    fmla v0.4s, v30.4s, v18.s[1]
+    fmla v1.4s, v31.4s, v18.s[1]
+    fmla v2.4s, v30.4s, v21.s[1]
+    fmla v3.4s, v31.4s, v21.s[1]
+    fmla v4.4s, v30.4s, v24.s[1]
+    fmla v5.4s, v31.4s, v24.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v18.s[2]
+    fmla v1.4s, v29.4s, v18.s[2]
+    fmla v2.4s, v28.4s, v21.s[2]
+    fmla v3.4s, v29.4s, v21.s[2]
+    fmla v4.4s, v28.4s, v24.s[2]
+    fmla v5.4s, v29.4s, v24.s[2]
+    fmla v0.4s, v30.4s, v18.s[3]
+    fmla v1.4s, v31.4s, v18.s[3]
+    fmla v2.4s, v30.4s, v21.s[3]
+    fmla v3.4s, v31.4s, v21.s[3]
+    fmla v4.4s, v30.4s, v24.s[3]
+    fmla v5.4s, v31.4s, v24.s[3]
+    subs x25, x25, #12
+    bge LoopC12
+    LoopC8:
+    adds x25, x25, #12
+    cbz x25, LoopCEnd
+    cmp x25, #8
+    blt LoopC4
+    ld1 {v16.4s, v17.4s}, [x24], #32
+    ld1 {v19.4s, v20.4s}, [x26], #32
+    ld1 {v22.4s, v23.4s}, [x19], #32
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    fmla v0.4s, v30.4s, v16.s[1]
+    fmla v1.4s, v31.4s, v16.s[1]
+    fmla v2.4s, v30.4s, v19.s[1]
+    fmla v3.4s, v31.4s, v19.s[1]
+    fmla v4.4s, v30.4s, v22.s[1]
+    fmla v5.4s, v31.4s, v22.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[2]
+    fmla v1.4s, v29.4s, v16.s[2]
+    fmla v2.4s, v28.4s, v19.s[2]
+    fmla v3.4s, v29.4s, v19.s[2]
+    fmla v4.4s, v28.4s, v22.s[2]
+    fmla v5.4s, v29.4s, v22.s[2]
+    fmla v0.4s, v30.4s, v16.s[3]
+    fmla v1.4s, v31.4s, v16.s[3]
+    fmla v2.4s, v30.4s, v19.s[3]
+    fmla v3.4s, v31.4s, v19.s[3]
+    fmla v4.4s, v30.4s, v22.s[3]
+    fmla v5.4s, v31.4s, v22.s[3]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[0]
+    fmla v1.4s, v29.4s, v17.s[0]
+    fmla v2.4s, v28.4s, v20.s[0]
+    fmla v3.4s, v29.4s, v20.s[0]
+    fmla v4.4s, v28.4s, v23.s[0]
+    fmla v5.4s, v29.4s, v23.s[0]
+    fmla v0.4s, v30.4s, v17.s[1]
+    fmla v1.4s, v31.4s, v17.s[1]
+    fmla v2.4s, v30.4s, v20.s[1]
+    fmla v3.4s, v31.4s, v20.s[1]
+    fmla v4.4s, v30.4s, v23.s[1]
+    fmla v5.4s, v31.4s, v23.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[2]
+    fmla v1.4s, v29.4s, v17.s[2]
+    fmla v2.4s, v28.4s, v20.s[2]
+    fmla v3.4s, v29.4s, v20.s[2]
+    fmla v4.4s, v28.4s, v23.s[2]
+    fmla v5.4s, v29.4s, v23.s[2]
+    fmla v0.4s, v30.4s, v17.s[3]
+    fmla v1.4s, v31.4s, v17.s[3]
+    fmla v2.4s, v30.4s, v20.s[3]
+    fmla v3.4s, v31.4s, v20.s[3]
+    fmla v4.4s, v30.4s, v23.s[3]
+    fmla v5.4s, v31.4s, v23.s[3]
+    sub x25, x25, #8
+    b LoopCTail
+    LoopC4:
+    cmp x25, #4
+    blt LoopCTail
+    ld1 {v16.4s}, [x24], #16
+    ld1 {v19.4s}, [x26], #16
+    ld1 {v22.4s}, [x19], #16
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    fmla v0.4s, v30.4s, v16.s[1]
+    fmla v1.4s, v31.4s, v16.s[1]
+    fmla v2.4s, v30.4s, v19.s[1]
+    fmla v3.4s, v31.4s, v19.s[1]
+    fmla v4.4s, v30.4s, v22.s[1]
+    fmla v5.4s, v31.4s, v22.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[2]
+    fmla v1.4s, v29.4s, v16.s[2]
+    fmla v2.4s, v28.4s, v19.s[2]
+    fmla v3.4s, v29.4s, v19.s[2]
+    fmla v4.4s, v28.4s, v22.s[2]
+    fmla v5.4s, v29.4s, v22.s[2]
+    fmla v0.4s, v30.4s, v16.s[3]
+    fmla v1.4s, v31.4s, v16.s[3]
+    fmla v2.4s, v30.4s, v19.s[3]
+    fmla v3.4s, v31.4s, v19.s[3]
+    fmla v4.4s, v30.4s, v22.s[3]
+    fmla v5.4s, v31.4s, v22.s[3]
+    sub x25, x25, #4
+    LoopCTail:
+    cbz x25, LoopCEnd
+    LoopCTailCycle:
+    ld1 {v16.s}[0], [x24], #4
+    ld1 {v19.s}[0], [x26], #4
+    ld1 {v22.s}[0], [x19], #4
+    ld1 {v28.4s, v29.4s}, [x2], #32
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    subs x25, x25, #1
+    bgt LoopCTailCycle
+    LoopCEnd:
+    add x23, x23, x11
+    subs x22, x22, #1
+    bgt LoopW
+    add x1, x1, x12
+    add x2, x2, x14
+    subs x4, x4, #1
+    bgt LoopH
+
+    ands x6, x6, #3
+    beq WriteBack
+    dup v24.2d, xzr // relu
+    fmax v0.4s, v0.4s, v24.4s
+    fmax v1.4s, v1.4s, v24.4s
+    fmax v2.4s, v2.4s, v24.4s
+    fmax v3.4s, v3.4s, v24.4s
+    fmax v4.4s, v4.4s, v24.4s
+    fmax v5.4s, v5.4s, v24.4s
+
+    ands x6, x6, #1
+    beq WriteBack
+    movi v24.4s, #6 // relu6
+    scvtf v24.4s, v24.4s
+    fmin v0.4s, v0.4s, v24.4s
+    fmin v1.4s, v1.4s, v24.4s
+    fmin v2.4s, v2.4s, v24.4s
+    fmin v3.4s, v3.4s, v24.4s
+    fmin v4.4s, v4.4s, v24.4s
+    fmin v5.4s, v5.4s, v24.4s
+
+    WriteBack:
+    add x21, x0, x7, LSL #1
+    cmp x15, #13
+    beq NC4HW4
+    st1 {v0.4s, v1.4s}, [x0]
+    st1 {v2.4s, v3.4s}, [x20]
+    st1 {v4.4s, v5.4s}, [x21]
+    b End
+    NC4HW4:
+    st1 {v0.4s}, [x0], #16
+    st1 {v2.4s}, [x0], #16
+    st1 {v4.4s}, [x0]
+    st1 {v1.4s}, [x20], #16
+    st1 {v3.4s}, [x20], #16
+    st1 {v5.4s}, [x20]
+    End:
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
+    ldp x25, x26, [sp], #16
+    ret
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x16Kernel.S
new file mode 100644
index 00000000..9d5400a6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x16Kernel.S
@@ -0,0 +1,662 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv4x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x27], #48 + ld1 {v25.4s, v26.4s, v27.4s}, [x28], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + 
fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + fmla v12.4s, v28.4s, v26.s[0] + fmla v13.4s, v29.4s, v26.s[0] + fmla v14.4s, v30.4s, v26.s[0] + fmla v15.4s, v31.4s, v26.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + fmla v12.4s, v28.4s, v26.s[1] + fmla v13.4s, v29.4s, v26.s[1] + fmla v14.4s, v30.4s, v26.s[1] + fmla v15.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + fmla v12.4s, v28.4s, v26.s[2] + fmla v13.4s, v29.4s, v26.s[2] + fmla v14.4s, v30.4s, v26.s[2] + fmla v15.4s, v31.4s, v26.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + fmla v12.4s, v28.4s, v26.s[3] + fmla v13.4s, v29.4s, v26.s[3] + fmla v14.4s, v30.4s, v26.s[3] + fmla v15.4s, v31.4s, v26.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, 
v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v10.4s, v30.4s, v24.s[0] + fmla v11.4s, v31.4s, v24.s[0] + fmla v12.4s, v28.4s, v27.s[0] + fmla v13.4s, v29.4s, v27.s[0] + fmla v14.4s, v30.4s, v27.s[0] + fmla v15.4s, v31.4s, v27.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v24.s[1] + fmla v9.4s, v29.4s, v24.s[1] + fmla v10.4s, v30.4s, v24.s[1] + fmla v11.4s, v31.4s, v24.s[1] + fmla v12.4s, v28.4s, v27.s[1] + fmla v13.4s, v29.4s, v27.s[1] + fmla v14.4s, v30.4s, v27.s[1] + fmla v15.4s, v31.4s, v27.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v10.4s, v30.4s, v24.s[2] + fmla v11.4s, v31.4s, v24.s[2] + fmla v12.4s, v28.4s, v27.s[2] + fmla v13.4s, v29.4s, v27.s[2] + fmla v14.4s, v30.4s, v27.s[2] + fmla v15.4s, v31.4s, v27.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v24.s[3] + fmla v9.4s, v29.4s, v24.s[3] + fmla v10.4s, v30.4s, v24.s[3] + fmla v11.4s, v31.4s, v24.s[3] + fmla v12.4s, v28.4s, v27.s[3] + fmla v13.4s, v29.4s, v27.s[3] + fmla v14.4s, v30.4s, v27.s[3] + fmla v15.4s, v31.4s, v27.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x27], #32 + ld1 {v25.4s, v26.4s}, [x28], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, 
v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + fmla v12.4s, v28.4s, v26.s[0] + fmla v13.4s, v29.4s, v26.s[0] + fmla v14.4s, v30.4s, v26.s[0] + fmla v15.4s, v31.4s, v26.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + fmla v12.4s, v28.4s, v26.s[1] + fmla v13.4s, v29.4s, v26.s[1] + fmla v14.4s, v30.4s, v26.s[1] + fmla v15.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + fmla v12.4s, v28.4s, v26.s[2] + fmla v13.4s, v29.4s, v26.s[2] + fmla v14.4s, v30.4s, v26.s[2] + fmla v15.4s, v31.4s, v26.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + fmla v12.4s, v28.4s, v26.s[3] + fmla v13.4s, v29.4s, v26.s[3] + fmla v14.4s, v30.4s, v26.s[3] + fmla v15.4s, v31.4s, v26.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v25.4s}, [x28], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla 
v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v25.s}[0], [x28], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + fmax v12.4s, v12.4s, v24.4s + fmax v13.4s, v13.4s, v24.4s + fmax v14.4s, v14.4s, v24.4s + fmax v15.4s, v15.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, 
v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + fmin v12.4s, v12.4s, v24.4s + fmin v13.4s, v13.4s, v24.4s + fmin v14.4s, v14.4s, v24.4s + fmin v15.4s, v15.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x22] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0], #16 + st1 {v12.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20], #16 + st1 {v13.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21], #16 + st1 {v14.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22], #16 + st1 {v15.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x8Kernel.S new file mode 100644 index 00000000..0de222ce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW4x8Kernel.S @@ -0,0 +1,406 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
+//                      size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step,
+//                      size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode)
+// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step,
+// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode
+asm_function SWConv4x8Kernel
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved
+    // whereas our coding style does not permit such an amount of parameters
+    sub sp, sp, #208
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp]
+    add x9, sp, #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
+    stp x19, x20, [sp, #128]
+    stp x21, x22, [sp, #144]
+    stp x23, x24, [sp, #160]
+    stp x25, x26, [sp, #176]
+    stp x27, x28, [sp, #192]
+
+    ldr x10, [sp, #208]
+    ldr x11, [sp, #216]
+    ldr x12, [sp, #224]
+    ldr x13, [sp, #232]
+    ldr x14, [sp, #240]
+    ldr x15, [sp, #248]
+    lsl x7, x7, #2
+    lsl x11, x11, #2
+    lsl x12, x12, #2
+    lsl x13, x13, #2
+    lsl x14, x14, #2
+    add x20, x0, x7
+
+    cbz x3, InitNoBias
+    InitWithBias:
+    ld1 {v0.4s, v1.4s}, [x3]
+    ld1 {v2.4s, v3.4s}, [x3]
+    ld1 {v4.4s, v5.4s}, [x3]
+    ld1 {v6.4s, v7.4s}, [x3]
+    b LoopH
+    InitNoBias:
+    dup v0.2d, xzr
+    dup v1.2d, xzr
+    dup v2.2d, xzr
+    dup v3.2d, xzr
+    dup v4.2d, xzr
+    dup v5.2d, xzr
+    dup v6.2d, xzr
+    dup v7.2d, xzr
+
+    LoopH:
+    mov x22, x5
+    mov x23, x1
+    LoopW:
+    mov x25, x10
+    prfm pldl1keep, [x23]
+    mov x24, x23
+    add x26, x23, x13
+    prfm pldl1keep, [x26]
+    add x27, x23, x13, lsl #1
+    prfm pldl1keep, [x27]
+    add x28, x27, x13
+    prfm pldl1keep, [x28]
+    subs x25, x25, #12
+    blt LoopC8
+    LoopC12:
+    ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48
+    ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48
+    ld1 {v22.4s, v23.4s, v24.4s}, [x27], #48
+    ld1 {v25.4s, v26.4s, v27.4s}, [x28], #48
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    fmla v6.4s, v28.4s, v25.s[0]
+    fmla v7.4s, v29.4s, v25.s[0]
+    fmla v0.4s, v30.4s, v16.s[1]
+    fmla v1.4s, v31.4s, v16.s[1]
+    fmla v2.4s, v30.4s, v19.s[1]
+    fmla v3.4s, v31.4s, v19.s[1]
+    fmla v4.4s, v30.4s, v22.s[1]
+    fmla v5.4s, v31.4s, v22.s[1]
+    fmla v6.4s, v30.4s, v25.s[1]
+    fmla v7.4s, v31.4s, v25.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[2]
+    fmla v1.4s, v29.4s, v16.s[2]
+    fmla v2.4s, v28.4s, v19.s[2]
+    fmla v3.4s, v29.4s, v19.s[2]
+    fmla v4.4s, v28.4s, v22.s[2]
+    fmla v5.4s, v29.4s, v22.s[2]
+    fmla v6.4s, v28.4s, v25.s[2]
+    fmla v7.4s, v29.4s, v25.s[2]
+    fmla v0.4s, v30.4s, v16.s[3]
+    fmla v1.4s, v31.4s, v16.s[3]
+    fmla v2.4s, v30.4s, v19.s[3]
+    fmla v3.4s, v31.4s, v19.s[3]
+    fmla v4.4s, v30.4s, v22.s[3]
+    fmla v5.4s, v31.4s, v22.s[3]
+    fmla v6.4s, v30.4s, v25.s[3]
+    fmla v7.4s, v31.4s, v25.s[3]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[0]
+    fmla v1.4s, v29.4s, v17.s[0]
+    fmla v2.4s, v28.4s, v20.s[0]
+    fmla v3.4s, v29.4s, v20.s[0]
+    fmla v4.4s, v28.4s, v23.s[0]
+    fmla v5.4s, v29.4s, v23.s[0]
+    fmla v6.4s, v28.4s, v26.s[0]
+    fmla v7.4s, v29.4s, v26.s[0]
+    fmla v0.4s, v30.4s, v17.s[1]
+    fmla v1.4s, v31.4s, v17.s[1]
+    fmla v2.4s, v30.4s, v20.s[1]
+    fmla v3.4s, v31.4s, v20.s[1]
+    fmla v4.4s, v30.4s, v23.s[1]
+    fmla v5.4s, v31.4s, v23.s[1]
+    fmla v6.4s, v30.4s, v26.s[1]
+    fmla v7.4s, v31.4s, v26.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[2]
+    fmla v1.4s, v29.4s, v17.s[2]
+    fmla v2.4s, v28.4s, v20.s[2]
+    fmla v3.4s, v29.4s, v20.s[2]
+    fmla v4.4s, v28.4s, v23.s[2]
+    fmla v5.4s, v29.4s, v23.s[2]
+    fmla v6.4s, v28.4s, v26.s[2]
+    fmla v7.4s, v29.4s, v26.s[2]
+    fmla v0.4s, v30.4s, v17.s[3]
+    fmla v1.4s, v31.4s, v17.s[3]
+    fmla v2.4s, v30.4s, v20.s[3]
+    fmla v3.4s, v31.4s, v20.s[3]
+    fmla v4.4s, v30.4s, v23.s[3]
+    fmla v5.4s, v31.4s, v23.s[3]
+    fmla v6.4s, v30.4s, v26.s[3]
+    fmla v7.4s, v31.4s, v26.s[3]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v18.s[0]
+    fmla v1.4s, v29.4s, v18.s[0]
+    fmla v2.4s, v28.4s, v21.s[0]
+    fmla v3.4s, v29.4s, v21.s[0]
+    fmla v4.4s, v28.4s, v24.s[0]
+    fmla v5.4s, v29.4s, v24.s[0]
+    fmla v6.4s, v28.4s, v27.s[0]
+    fmla v7.4s, v29.4s, v27.s[0]
+    fmla v0.4s, v30.4s, v18.s[1]
+    fmla v1.4s, v31.4s, v18.s[1]
+    fmla v2.4s, v30.4s, v21.s[1]
+    fmla v3.4s, v31.4s, v21.s[1]
+    fmla v4.4s, v30.4s, v24.s[1]
+    fmla v5.4s, v31.4s, v24.s[1]
+    fmla v6.4s, v30.4s, v27.s[1]
+    fmla v7.4s, v31.4s, v27.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v18.s[2]
+    fmla v1.4s, v29.4s, v18.s[2]
+    fmla v2.4s, v28.4s, v21.s[2]
+    fmla v3.4s, v29.4s, v21.s[2]
+    fmla v4.4s, v28.4s, v24.s[2]
+    fmla v5.4s, v29.4s, v24.s[2]
+    fmla v6.4s, v28.4s, v27.s[2]
+    fmla v7.4s, v29.4s, v27.s[2]
+    fmla v0.4s, v30.4s, v18.s[3]
+    fmla v1.4s, v31.4s, v18.s[3]
+    fmla v2.4s, v30.4s, v21.s[3]
+    fmla v3.4s, v31.4s, v21.s[3]
+    fmla v4.4s, v30.4s, v24.s[3]
+    fmla v5.4s, v31.4s, v24.s[3]
+    fmla v6.4s, v30.4s, v27.s[3]
+    fmla v7.4s, v31.4s, v27.s[3]
+    subs x25, x25, #12
+    bge LoopC12
+    LoopC8:
+    adds x25, x25, #12
+    cbz x25, LoopCEnd
+    cmp x25, #8
+    blt LoopC4
+    ld1 {v16.4s, v17.4s}, [x24], #32
+    ld1 {v19.4s, v20.4s}, [x26], #32
+    ld1 {v22.4s, v23.4s}, [x27], #32
+    ld1 {v25.4s, v26.4s}, [x28], #32
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    fmla v6.4s, v28.4s, v25.s[0]
+    fmla v7.4s, v29.4s, v25.s[0]
+    fmla v0.4s, v30.4s, v16.s[1]
+    fmla v1.4s, v31.4s, v16.s[1]
+    fmla v2.4s, v30.4s, v19.s[1]
+    fmla v3.4s, v31.4s, v19.s[1]
+    fmla v4.4s, v30.4s, v22.s[1]
+    fmla v5.4s, v31.4s, v22.s[1]
+    fmla v6.4s, v30.4s, v25.s[1]
+    fmla v7.4s, v31.4s, v25.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[2]
+    fmla v1.4s, v29.4s, v16.s[2]
+    fmla v2.4s, v28.4s, v19.s[2]
+    fmla v3.4s, v29.4s, v19.s[2]
+    fmla v4.4s, v28.4s, v22.s[2]
+    fmla v5.4s, v29.4s, v22.s[2]
+    fmla v6.4s, v28.4s, v25.s[2]
+    fmla v7.4s, v29.4s, v25.s[2]
+    fmla v0.4s, v30.4s, v16.s[3]
+    fmla v1.4s, v31.4s, v16.s[3]
+    fmla v2.4s, v30.4s, v19.s[3]
+    fmla v3.4s, v31.4s, v19.s[3]
+    fmla v4.4s, v30.4s, v22.s[3]
+    fmla v5.4s, v31.4s, v22.s[3]
+    fmla v6.4s, v30.4s, v25.s[3]
+    fmla v7.4s, v31.4s, v25.s[3]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[0]
+    fmla v1.4s, v29.4s, v17.s[0]
+    fmla v2.4s, v28.4s, v20.s[0]
+    fmla v3.4s, v29.4s, v20.s[0]
+    fmla v4.4s, v28.4s, v23.s[0]
+    fmla v5.4s, v29.4s, v23.s[0]
+    fmla v6.4s, v28.4s, v26.s[0]
+    fmla v7.4s, v29.4s, v26.s[0]
+    fmla v0.4s, v30.4s, v17.s[1]
+    fmla v1.4s, v31.4s, v17.s[1]
+    fmla v2.4s, v30.4s, v20.s[1]
+    fmla v3.4s, v31.4s, v20.s[1]
+    fmla v4.4s, v30.4s, v23.s[1]
+    fmla v5.4s, v31.4s, v23.s[1]
+    fmla v6.4s, v30.4s, v26.s[1]
+    fmla v7.4s, v31.4s, v26.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v17.s[2]
+    fmla v1.4s, v29.4s, v17.s[2]
+    fmla v2.4s, v28.4s, v20.s[2]
+    fmla v3.4s, v29.4s, v20.s[2]
+    fmla v4.4s, v28.4s, v23.s[2]
+    fmla v5.4s, v29.4s, v23.s[2]
+    fmla v6.4s, v28.4s, v26.s[2]
+    fmla v7.4s, v29.4s, v26.s[2]
+    fmla v0.4s, v30.4s, v17.s[3]
+    fmla v1.4s, v31.4s, v17.s[3]
+    fmla v2.4s, v30.4s, v20.s[3]
+    fmla v3.4s, v31.4s, v20.s[3]
+    fmla v4.4s, v30.4s, v23.s[3]
+    fmla v5.4s, v31.4s, v23.s[3]
+    fmla v6.4s, v30.4s, v26.s[3]
+    fmla v7.4s, v31.4s, v26.s[3]
+    sub x25, x25, #8
+    b LoopCTail
+    LoopC4:
+    cmp x25, #4
+    blt LoopCTail
+    ld1 {v16.4s}, [x24], #16
+    ld1 {v19.4s}, [x26], #16
+    ld1 {v22.4s}, [x27], #16
+    ld1 {v25.4s}, [x28], #16
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    fmla v6.4s, v28.4s, v25.s[0]
+    fmla v7.4s, v29.4s, v25.s[0]
+    fmla v0.4s, v30.4s, v16.s[1]
+    fmla v1.4s, v31.4s, v16.s[1]
+    fmla v2.4s, v30.4s, v19.s[1]
+    fmla v3.4s, v31.4s, v19.s[1]
+    fmla v4.4s, v30.4s, v22.s[1]
+    fmla v5.4s, v31.4s, v22.s[1]
+    fmla v6.4s, v30.4s, v25.s[1]
+    fmla v7.4s, v31.4s, v25.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v16.s[2]
+    fmla v1.4s, v29.4s, v16.s[2]
+    fmla v2.4s, v28.4s, v19.s[2]
+    fmla v3.4s, v29.4s, v19.s[2]
+    fmla v4.4s, v28.4s, v22.s[2]
+    fmla v5.4s, v29.4s, v22.s[2]
+    fmla v6.4s, v28.4s, v25.s[2]
+    fmla v7.4s, v29.4s, v25.s[2]
+    fmla v0.4s, v30.4s, v16.s[3]
+    fmla v1.4s, v31.4s, v16.s[3]
+    fmla v2.4s, v30.4s, v19.s[3]
+    fmla v3.4s, v31.4s, v19.s[3]
+    fmla v4.4s, v30.4s, v22.s[3]
+    fmla v5.4s, v31.4s, v22.s[3]
+    fmla v6.4s, v30.4s, v25.s[3]
+    fmla v7.4s, v31.4s, v25.s[3]
+    sub x25, x25, #4
+    LoopCTail:
+    cbz x25, LoopCEnd
+    LoopCTailCycle:
+    ld1 {v16.s}[0], [x24], #4
+    ld1 {v19.s}[0], [x26], #4
+    ld1 {v22.s}[0], [x27], #4
+    ld1 {v25.s}[0], [x28], #4
+    ld1 {v28.4s, v29.4s}, [x2], #32
+    fmla v0.4s, v28.4s, v16.s[0]
+    fmla v1.4s, v29.4s, v16.s[0]
+    fmla v2.4s, v28.4s, v19.s[0]
+    fmla v3.4s, v29.4s, v19.s[0]
+    fmla v4.4s, v28.4s, v22.s[0]
+    fmla v5.4s, v29.4s, v22.s[0]
+    fmla v6.4s, v28.4s, v25.s[0]
+    fmla v7.4s, v29.4s, v25.s[0]
+    subs x25, x25, #1
+    bgt LoopCTailCycle
+    LoopCEnd:
+    add x23, x23, x11
+    subs x22, x22, #1
+    bgt LoopW
+    add x1, x1, x12
+    add x2, x2, x14
+    subs x4, x4, #1
+    bgt LoopH
+
+    ands x6, x6, #3
+    beq WriteBack
+    dup v24.2d, xzr // relu
+    fmax v0.4s, v0.4s, v24.4s
+    fmax v1.4s, v1.4s, v24.4s
+    fmax v2.4s, v2.4s, v24.4s
+    fmax v3.4s, v3.4s, v24.4s
+    fmax v4.4s, v4.4s, v24.4s
+    fmax v5.4s, v5.4s, v24.4s
+    fmax v6.4s, v6.4s, v24.4s
+    fmax v7.4s, v7.4s, v24.4s
+
+    ands x6, x6, #1
+    beq WriteBack
+    movi v24.4s, #6 // relu6
+    scvtf v24.4s, v24.4s
+    fmin v0.4s, v0.4s, v24.4s
+    fmin v1.4s, v1.4s, v24.4s
+    fmin v2.4s, v2.4s, v24.4s
+    fmin v3.4s, v3.4s, v24.4s
+    fmin v4.4s, v4.4s, v24.4s
+    fmin v5.4s, v5.4s, v24.4s
+    fmin v6.4s, v6.4s, v24.4s
+    fmin v7.4s, v7.4s, v24.4s
+
+    WriteBack:
+    add x21, x0, x7, LSL #1
+    cmp x15, #13
+    beq NC4HW4
+    add x22, x21, x7
+    st1 {v0.4s, v1.4s}, [x0]
+    st1 {v2.4s, v3.4s}, [x20]
+    st1 {v4.4s, v5.4s}, [x21]
+    st1 {v6.4s, v7.4s}, [x22]
+    b End
+    NC4HW4:
+    st1 {v0.4s}, [x0], #16
+    st1 {v2.4s}, [x0], #16
+    st1 {v4.4s}, [x0], #16
+    st1 {v6.4s}, [x0]
+    st1 {v1.4s}, [x20], #16
+    st1 {v3.4s}, [x20], #16
+    st1 {v5.4s}, [x20], #16
+    st1 {v7.4s}, [x20]
+    End:
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
+    ldp x25, x26, [sp], #16
+    ldp x27, x28, [sp], #16
+    ret
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x16Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x16Kernel.S
new file mode 100644
index 00000000..11583d53
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x16Kernel.S
@@ -0,0 +1,457 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void SWConv5x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
+//                       size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step,
+//                       size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode)
+// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step,
+// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode
+asm_function SWConv5x16Kernel
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved
+    // whereas our coding style does not permit such an amount of parameters
+    sub sp, sp, #208
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp]
+    add x9, sp, #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
+    stp x19, x20, [sp, #128]
+    stp x21, x22, [sp, #144]
+    stp x23, x24, [sp, #160]
+    stp x25, x26, [sp, #176]
+    stp x27, x28, [sp, #192]
+
+    ldr x10, [sp, #208]
+    ldr x11, [sp, #216]
+    ldr x12, [sp, #224]
+    ldr x13, [sp, #232]
+    ldr x14, [sp, #240]
+    ldr x15, [sp, #248]
+    lsl x7, x7, #2
+    lsl x11, x11, #2
+    lsl x12, x12, #2
+    lsl x13, x13, #2
+    lsl x14, x14, #2
+
+    cbz x3, InitNoBias
+    InitWithBias:
+    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3]
+    ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3]
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3]
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3]
+    ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x3]
+    b LoopH
+    InitNoBias:
+    dup v0.2d, xzr
+    dup v1.2d, xzr
+    dup v2.2d, xzr
+    dup v3.2d, xzr
+    dup v4.2d, xzr
+    dup v5.2d, xzr
+    dup v6.2d, xzr
+    dup v7.2d, xzr
+    dup v8.2d, xzr
+    dup v9.2d, xzr
+    dup v10.2d, xzr
+    dup v11.2d, xzr
+    dup v12.2d, xzr
+    dup v13.2d, xzr
+    dup v14.2d, xzr
+    dup v15.2d, xzr
+    dup v16.2d, xzr
+    dup v17.2d, xzr
+    dup v18.2d, xzr
+    dup v19.2d, xzr
+
+    LoopH:
+    mov x22, x5
+    mov x23, x1
+    LoopW:
+    mov x25, x10
+    prfm pldl1keep, [x23]
+    mov x24, x23
+    add x26, x23, x13
+    prfm pldl1keep, [x26]
+    add x27, x23, x13, lsl #1
+    prfm pldl1keep, [x27]
+    add x28, x27, x13
+    prfm pldl1keep, [x28]
+    add x20, x23, x13, lsl #2
+    prfm pldl1keep, [x20]
+    subs x25, x25, #4
+    blt LoopCTail
+    LoopC4:
+    ld1 {v20.4s}, [x24], #16
+    ld1 {v21.4s}, [x26], #16
+    ld1 {v22.4s}, [x27], #16
+    ld1 {v23.4s}, [x28], #16
+    ld1 {v24.4s}, [x20], #16
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v20.s[0]
+    fmla v1.4s, v29.4s, v20.s[0]
+    fmla v2.4s, v30.4s, v20.s[0]
+    fmla v3.4s, v31.4s, v20.s[0]
+    fmla v4.4s, v28.4s, v21.s[0]
+    fmla v5.4s, v29.4s, v21.s[0]
+    fmla v6.4s, v30.4s, v21.s[0]
+    fmla v7.4s, v31.4s, v21.s[0]
+    fmla v8.4s, v28.4s, v22.s[0]
+    fmla v9.4s, v29.4s, v22.s[0]
+    fmla v10.4s, v30.4s, v22.s[0]
+    fmla v11.4s, v31.4s, v22.s[0]
+    fmla v12.4s, v28.4s, v23.s[0]
+    fmla v13.4s, v29.4s, v23.s[0]
+    fmla v14.4s, v30.4s, v23.s[0]
+    fmla v15.4s, v31.4s, v23.s[0]
+    fmla v16.4s, v28.4s, v24.s[0]
+    fmla v17.4s, v29.4s, v24.s[0]
+    fmla v18.4s, v30.4s, v24.s[0]
+    fmla v19.4s, v31.4s, v24.s[0]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v20.s[1]
+    fmla v1.4s, v29.4s, v20.s[1]
+    fmla v2.4s, v30.4s, v20.s[1]
+    fmla v3.4s, v31.4s, v20.s[1]
+    fmla v4.4s, v28.4s, v21.s[1]
+    fmla v5.4s, v29.4s, v21.s[1]
+    fmla v6.4s, v30.4s, v21.s[1]
+    fmla v7.4s, v31.4s, v21.s[1]
+    fmla v8.4s, v28.4s, v22.s[1]
+    fmla v9.4s, v29.4s, v22.s[1]
+    fmla v10.4s, v30.4s, v22.s[1]
+    fmla v11.4s, v31.4s, v22.s[1]
+    fmla v12.4s, v28.4s, v23.s[1]
+    fmla v13.4s, v29.4s, v23.s[1]
+    fmla v14.4s, v30.4s, v23.s[1]
+    fmla v15.4s, v31.4s, v23.s[1]
+    fmla v16.4s, v28.4s, v24.s[1]
+    fmla v17.4s, v29.4s, v24.s[1]
+    fmla v18.4s, v30.4s, v24.s[1]
+    fmla v19.4s, v31.4s, v24.s[1]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v20.s[2]
+    fmla v1.4s, v29.4s, v20.s[2]
+    fmla v2.4s, v30.4s, v20.s[2]
+    fmla v3.4s, v31.4s, v20.s[2]
+    fmla v4.4s, v28.4s, v21.s[2]
+    fmla v5.4s, v29.4s, v21.s[2]
+    fmla v6.4s, v30.4s, v21.s[2]
+    fmla v7.4s, v31.4s, v21.s[2]
+    fmla v8.4s, v28.4s, v22.s[2]
+    fmla v9.4s, v29.4s, v22.s[2]
+    fmla v10.4s, v30.4s, v22.s[2]
+    fmla v11.4s, v31.4s, v22.s[2]
+    fmla v12.4s, v28.4s, v23.s[2]
+    fmla v13.4s, v29.4s, v23.s[2]
+    fmla v14.4s, v30.4s, v23.s[2]
+    fmla v15.4s, v31.4s, v23.s[2]
+    fmla v16.4s, v28.4s, v24.s[2]
+    fmla v17.4s, v29.4s, v24.s[2]
+    fmla v18.4s, v30.4s, v24.s[2]
+    fmla v19.4s, v31.4s, v24.s[2]
+
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v20.s[3]
+    fmla v1.4s, v29.4s, v20.s[3]
+    fmla v2.4s, v30.4s, v20.s[3]
+    fmla v3.4s, v31.4s, v20.s[3]
+    fmla v4.4s, v28.4s, v21.s[3]
+    fmla v5.4s, v29.4s, v21.s[3]
+    fmla v6.4s, v30.4s, v21.s[3]
+    fmla v7.4s, v31.4s, v21.s[3]
+    fmla v8.4s, v28.4s, v22.s[3]
+    fmla v9.4s, v29.4s, v22.s[3]
+    fmla v10.4s, v30.4s, v22.s[3]
+    fmla v11.4s, v31.4s, v22.s[3]
+    fmla v12.4s, v28.4s, v23.s[3]
+    fmla v13.4s, v29.4s, v23.s[3]
+    fmla v14.4s, v30.4s, v23.s[3]
+    fmla v15.4s, v31.4s, v23.s[3]
+    fmla v16.4s, v28.4s, v24.s[3]
+    fmla v17.4s, v29.4s, v24.s[3]
+    fmla v18.4s, v30.4s, v24.s[3]
+    fmla v19.4s, v31.4s, v24.s[3]
+    subs x25, x25, #4
+    bge LoopC4
+    LoopCTail:
+    add x25, x25, #4
+    cbz x25, LoopCEnd
+    cmp x25, #3
+    beq LoopCTail3
+    cmp x25, #2
+    beq LoopCTail2
+    ld1 {v20.s}[0], [x24], #4
+    ld1 {v21.s}[0], [x26], #4
+    ld1 {v22.s}[0], [x27], #4
+    ld1 {v23.s}[0], [x28], #4
+    ld1 {v24.s}[0], [x20], #4
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
+    fmla v0.4s, v28.4s, v20.s[0]
+    fmla v1.4s, v29.4s, v20.s[0]
+    fmla v2.4s, v30.4s, v20.s[0]
+    fmla v3.4s, v31.4s, v20.s[0]
+    fmla v4.4s, v28.4s, 
v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + b LoopCEnd + LoopCTail2: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + b LoopCEnd + LoopCTail3: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v20.s}[2], [x24], #4 + ld1 {v21.s}[2], [x26], #4 + ld1 {v22.s}[2], [x27], #4 + ld1 {v23.s}[2], [x28], #4 + ld1 {v24.s}[2], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, 
v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v30.4s, v20.s[2] + fmla v3.4s, v31.4s, v20.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v23.s[2] + fmla v13.4s, v29.4s, v23.s[2] + fmla v14.4s, v30.4s, v23.s[2] + fmla v15.4s, v31.4s, v23.s[2] + fmla v16.4s, v28.4s, v24.s[2] + fmla v17.4s, v29.4s, v24.s[2] + fmla v18.4s, v30.4s, v24.s[2] + fmla v19.4s, v31.4s, v24.s[2] + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + fmax v12.4s, v12.4s, v24.4s + fmax v13.4s, v13.4s, v24.4s + fmax v14.4s, v14.4s, v24.4s + fmax v15.4s, v15.4s, v24.4s + fmax v16.4s, v16.4s, v24.4s + fmax v17.4s, v17.4s, v24.4s + fmax v18.4s, v18.4s, v24.4s + fmax v19.4s, v19.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + fmin v12.4s, v12.4s, v24.4s + fmin v13.4s, v13.4s, v24.4s + fmin v14.4s, v14.4s, v24.4s + fmin v15.4s, v15.4s, v24.4s + fmin v16.4s, v16.4s, v24.4s + fmin v17.4s, v17.4s, v24.4s + fmin v18.4s, v18.4s, v24.4s + fmin v19.4s, v19.4s, v24.4s + + WriteBack: + add x20, x0, x7 + add x21, x0, x7, LSL #1 + add x23, x0, x7, LSL #2 + add x22, x20, x7, LSL #1 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x22] + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x23] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0], #16 + st1 {v12.4s}, [x0], #16 + st1 {v16.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20], #16 + st1 {v13.4s}, [x20], #16 + st1 {v17.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21], #16 + st1 {v14.4s}, [x21], #16 + st1 {v18.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22], #16 + st1 {v15.4s}, [x22], #16 + st1 {v19.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x8Kernel.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x8Kernel.S new file mode 
100644 index 00000000..58181f0f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/ConvSW5x8Kernel.S @@ -0,0 +1,308 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SWConv5x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv5x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s}, [x3] + ld1 {v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + add x20, x23, x13, lsl #2 + prfm pldl1keep, [x20] + subs x25, x25, #4 + blt LoopCTail + LoopC4: + ld1 {v20.4s}, [x24], #16 + ld1 {v21.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v23.4s}, [x28], #16 + ld1 {v24.4s}, [x20], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, 
v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v23.s[2] + fmla v7.4s, v29.4s, v23.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v0.4s, v30.4s, v20.s[3] + fmla v1.4s, v31.4s, v20.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v23.s[3] + fmla v7.4s, v31.4s, v23.s[3] + fmla v8.4s, v30.4s, v24.s[3] + fmla v9.4s, v31.4s, v24.s[3] + subs x25, x25, #4 + bge LoopC4 + LoopCTail: + add x25, x25, #4 + cbz x25, LoopCEnd + cmp x25, #3 + beq LoopCTail3 + cmp x25, #2 + beq LoopCTail2 + ld1 {v20.s}[0], [x24], #4 + ld1 {v21.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v23.s}[0], [x28], #4 + ld1 {v24.s}[0], [x20], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + b LoopCEnd + LoopCTail2: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + b LoopCEnd + LoopCTail3: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v20.s}[2], [x24], #4 + ld1 {v21.s}[2], [x26], #4 + ld1 {v22.s}[2], [x27], #4 + ld1 {v23.s}[2], [x28], #4 + ld1 {v24.s}[2], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + ld1 {v26.4s, v27.4s}, [x2], #32 + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + fmla v0.4s, v26.4s, v20.s[2] + fmla v1.4s, v27.4s, v20.s[2] + fmla v2.4s, v26.4s, v21.s[2] + fmla v3.4s, v27.4s, v21.s[2] + fmla v4.4s, v26.4s, v22.s[2] + fmla v5.4s, v27.4s, v22.s[2] + fmla v6.4s, v26.4s, v23.s[2] + fmla v7.4s, v27.4s, v23.s[2] + fmla v8.4s, v26.4s, v24.s[2] + fmla 
v9.4s, v27.4s, v24.s[2] + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + + WriteBack: + add x20, x0, x7 + cmp x15, #13 + beq NC4HW4 + add x21, x0, x7, LSL #1 + add x23, x0, x7, LSL #2 + add x22, x20, x7, LSL #1 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + st1 {v4.4s, v5.4s}, [x21] + st1 {v6.4s, v7.4s}, [x22] + st1 {v8.4s, v9.4s}, [x23] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v6.4s}, [x0], #16 + st1 {v8.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v7.4s}, [x20], #16 + st1 {v9.4s}, [x20] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Border.S new file mode 100644 index 00000000..74723e98 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Border.S @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width,
+//                         size_t in_kh_step, size_t in_kw_step, size_t kernel_w)
+
+// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: in_kh_step, x6: in_kw_step, x7: kernel_w
+asm_function DeconvDwFp32Border
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved,
+    // whereas our coding style does not permit such an amount of parameters
+    cmp x3, #0
+    beq End
+    cmp x4, #0
+    beq End
+    ld1 {v1.4s}, [x1]
+
+    mov x13, x0
+    mov x14, x2
+    LoopH:
+        mov x15, x13
+        mov x16, x14
+        mov x17, x4
+        LoopW:
+            ld1 {v0.4s}, [x15]
+            ld1 {v2.4s}, [x16], #16
+            fmla v0.4s, v1.4s, v2.4s
+            st1 {v0.4s}, [x15], x6
+            subs x17, x17, #1
+            bne LoopW
+        subs x3, x3, #1
+        add x13, x13, x5
+        add x14, x14, x7
+        bne LoopH
+    End:
+        ret
+#endif
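For readers tracing the new border kernel, here is a minimal scalar C sketch of what DeconvDwFp32Border computes; it is illustrative only (the helper name is hypothetical, not part of this patch). As in the assembly, src holds a single C4-packed pixel (4 floats), in_kh_step and in_kw_step are byte strides on dst, and kernel_w here acts as the byte stride between C4 weight rows:

    #include <stddef.h>

    /* Hypothetical scalar reference for the DeconvDwFp32Border micro-kernel. */
    static void DeconvDwFp32BorderRef(float *dst, const float *src, const float *weight,
                                      size_t height, size_t width, size_t in_kh_step,
                                      size_t in_kw_step, size_t kernel_w) {
      char *dst_kh = (char *)dst;
      const char *w_row = (const char *)weight;
      for (size_t kh = 0; kh < height; ++kh) {    /* LoopH over kernel rows */
        char *dst_kw = dst_kh;
        const float *w = (const float *)w_row;
        for (size_t kw = 0; kw < width; ++kw) {   /* LoopW over kernel columns */
          float *d = (float *)dst_kw;
          for (int c = 0; c < 4; ++c) {
            d[c] += src[c] * w[c];                /* fmla v0.4s, v1.4s, v2.4s */
          }
          w += 4;                                 /* ld1 {v2.4s}, [x16], #16 */
          dst_kw += in_kw_step;                   /* st1 {v0.4s}, [x15], x6 */
        }
        dst_kh += in_kh_step;
        w_row += kernel_w;
      }
    }

The assembly simply vectorizes the innermost 4-channel loop into one fmla per (kh, kw) tap, with the input pixel loaded once up front.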
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Center.S
new file mode 100644
index 00000000..1ef311ee
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwFp32Center.S
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width,
+//                         size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
+//                         size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
+// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: kernel_h, x6: kernel_w, x7: out_h_step
+// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step
+asm_function DeconvDwFp32Center
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved,
+    // whereas our coding style does not permit such an amount of parameters
+    sub sp, sp, #32
+    stp x19, x20, [sp]
+    stp x21, x22, [sp, #16]
+
+    ldr x8, [sp, #32]
+    ldr x9, [sp, #40]
+    ldr x10, [sp, #48]
+    ldr x11, [sp, #56]
+    ldr x12, [sp, #64]
+
+    LoopH:
+        mov x15, x0
+        mov x16, x1
+        mov x17, x4
+        LoopW:
+            mov x22, x15
+            mov x19, x2
+            mov x20, x5
+            ld1 {v1.4s}, [x16], x8
+            LoopKh:
+                mov x21, x22
+                mov x13, x6
+                LoopKw:
+                    ld1 {v0.4s}, [x21]
+                    ld1 {v2.4s}, [x19], #16
+                    fmla v0.4s, v1.4s, v2.4s
+                    st1 {v0.4s}, [x21], x12
+                    subs x13, x13, #1
+                    bne LoopKw
+                add x22, x22, x11
+                subs x20, x20, #1
+                bne LoopKh
+            add x15, x15, x10
+            subs x17, x17, #1
+            bne LoopW
+        add x0, x0, x9
+        add x1, x1, x7
+        subs x3, x3, #1
+        bne LoopH
+
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ret
+#endif
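The center kernel is the same per-tap accumulation wrapped in output-row/column loops. A hedged scalar C sketch follows (hypothetical helper, not part of this patch); as in the assembly, all *_step values are byte strides, src advances block_channel bytes per output pixel, and the weight pointer restarts for every output pixel:

    #include <stddef.h>

    /* Hypothetical scalar reference for the DeconvDwFp32Center micro-kernel. */
    static void DeconvDwFp32CenterRef(float *dst, const float *src, const float *weight,
                                      size_t height, size_t width, size_t kernel_h,
                                      size_t kernel_w, size_t out_h_step, size_t block_channel,
                                      size_t in_sh_step, size_t in_sw_step, size_t in_kh_step,
                                      size_t in_kw_step) {
      for (size_t oh = 0; oh < height; ++oh) {          /* LoopH: one output row */
        char *dst_w = (char *)dst;
        const char *src_w = (const char *)src;
        for (size_t ow = 0; ow < width; ++ow) {         /* LoopW: one output pixel */
          const float *in = (const float *)src_w;       /* ld1 {v1.4s}, [x16], x8 */
          const float *w = weight;                      /* weight restarts per pixel */
          char *dst_kh = dst_w;
          for (size_t kh = 0; kh < kernel_h; ++kh) {    /* LoopKh */
            char *dst_kw = dst_kh;
            for (size_t kw = 0; kw < kernel_w; ++kw) {  /* LoopKw */
              float *d = (float *)dst_kw;
              for (int c = 0; c < 4; ++c) {
                d[c] += in[c] * w[c];                   /* fmla v0.4s, v1.4s, v2.4s */
              }
              w += 4;
              dst_kw += in_kw_step;
            }
            dst_kh += in_kh_step;
          }
          dst_w += in_sw_step;
          src_w += block_channel;
        }
        dst = (float *)((char *)dst + in_sh_step);
        src = (const float *)((const char *)src + out_h_step);
      }
    }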
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Center.S
new file mode 100644
index 00000000..299d3700
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Center.S
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
+//                         size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
+//                         size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
+// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: kernel_h, x6: kernel_w, x7: out_h_step
+// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step
+asm_function DeconvDwInt8Center
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved,
+    // whereas our coding style does not permit such an amount of parameters
+    sub sp, sp, #32
+    stp x19, x20, [sp]
+    stp x21, x22, [sp, #16]
+
+    ldr x8, [sp, #32]
+    ldr x9, [sp, #40]
+    ldr x10, [sp, #48]
+    ldr x11, [sp, #56]
+    ldr x12, [sp, #64]
+
+    LoopH:
+        mov x15, x0
+        mov x16, x1
+        mov x17, x4
+        LoopW:
+            mov x18, x15
+            mov x19, x2
+            mov x20, x5
+            ld1 {v1.4h}, [x16], x8
+            LoopKh:
+                mov x21, x18
+                mov x13, x6
+                LoopKw:
+                    ld1 {v0.4s}, [x21]
+                    ld1 {v2.4h}, [x19], #8
+                    smlal v0.4s, v1.4h, v2.4h
+                    st1 {v0.4s}, [x21], x12
+                    subs x13, x13, #1
+                    bne LoopKw
+                add x18, x18, x11
+                subs x20, x20, #1
+                bne LoopKh
+            add x15, x15, x10
+            subs x17, x17, #1
+            bne LoopW
+        add x0, x0, x9
+        add x1, x1, x7
+        subs x3, x3, #1
+        bne LoopH
+
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ret
+#endif
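The int8 center kernel uses the identical loop nest, but each tap widens an int16 x int16 product into an int32 accumulator (the smlal instruction). A hedged scalar sketch under the same byte-stride assumptions (hypothetical helper, not part of this patch):

    #include <stddef.h>
    #include <stdint.h>

    /* Hypothetical scalar reference for the DeconvDwInt8Center micro-kernel. */
    static void DeconvDwInt8CenterRef(int32_t *dst, const int16_t *src, const int16_t *weight,
                                      size_t height, size_t width, size_t kernel_h,
                                      size_t kernel_w, size_t out_h_step, size_t block_channel,
                                      size_t in_sh_step, size_t in_sw_step, size_t in_kh_step,
                                      size_t in_kw_step) {
      for (size_t oh = 0; oh < height; ++oh) {
        char *dst_w = (char *)dst;
        const char *src_w = (const char *)src;
        for (size_t ow = 0; ow < width; ++ow) {
          const int16_t *in = (const int16_t *)src_w;   /* one C4 pixel, reused per window */
          const int16_t *w = weight;                    /* weight restarts per output pixel */
          char *dst_kh = dst_w;
          for (size_t kh = 0; kh < kernel_h; ++kh) {
            char *dst_kw = dst_kh;
            for (size_t kw = 0; kw < kernel_w; ++kw) {
              int32_t *d = (int32_t *)dst_kw;
              for (int c = 0; c < 4; ++c) {
                d[c] += (int32_t)in[c] * (int32_t)w[c]; /* smlal v0.4s, v1.4h, v2.4h */
              }
              w += 4;
              dst_kw += in_kw_step;
            }
            dst_kh += in_kh_step;
          }
          dst_w += in_sw_step;
          src_w += block_channel;
        }
        dst = (int32_t *)((char *)dst + in_sh_step);
        src = (const int16_t *)((const char *)src + out_h_step);
      }
    }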
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Post.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Post.S
new file mode 100644
index 00000000..ff9d6a64
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DeconvDwInt8Post.S
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums,
+//                       int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min,
+//                       int32_t acc_max)
+// x0: dst, x1: output_buffer, x2: bias, x3: block_channel, x4: pixel_nums, x5: out_multiplier
+// x6: left_shift, x7: right_shift, x8: out_zp, x9: acc_min, x10: acc_max
+
+asm_function DeconvDwInt8Post
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved,
+    // whereas our coding style does not permit such an amount of parameters
+    ld1 {v25.4s}, [x2]
+
+    dup v26.4s, w6     // left_shift
+    dup v27.4s, w5     // out_multiplier
+    dup v28.4s, w7     // right_shift
+
+    ldr w8, [sp]
+    dup v29.4s, w8     // out_zp
+    ldr w9, [sp, #8]
+    dup v30.4s, w9     // acc_min
+    ldr w10, [sp, #16]
+    dup v31.4s, w10    // acc_max
+
+    LoopCount:
+        ld1 {v0.4s}, [x1], #16
+        add v0.4s, v0.4s, v25.4s
+        sqshl v0.4s, v0.4s, v26.4s
+        sqrdmulh v0.4s, v0.4s, v27.4s
+        sqrshl v0.4s, v0.4s, v28.4s
+
+        add v0.4s, v0.4s, v29.4s
+        smax v0.4s, v0.4s, v30.4s
+        smin v0.4s, v0.4s, v31.4s
+
+        sqxtn v0.4h, v0.4s
+        sqxtn v0.8b, v0.8h
+
+        st1 {v0.s}[0], [x0], x3
+
+        sub x4, x4, #1
+        cmp x4, #1
+        bge LoopCount
+    ret
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DynamicGatherArm64.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DynamicGatherArm64.S
new file mode 100644
index 00000000..ef3a39fb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/DynamicGatherArm64.S
@@ -0,0 +1,48 @@
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+.text
+.align 5
+
+// void DynamicGatherArm64(const int8_t *src, float *output, int count_16, int zp, float scale);
+// x0: src(left matrix ptr)
+// x1: output(right matrix ptr)
+// w2: count_16
+// w3: zp
+// v0.s[0]: scale (a float argument is passed in v0 under AAPCS64)
+
+asm_function DynamicGatherArm64
+    mov x5, x0    // reload src
+    mov x6, x1    // reload out
+    mov w7, w2    // reload count_16
+    dup v1.4s, w3        // zp
+    dup v2.4s, v0.s[0]   // scale
+
+    LoopCount:
+        ld1 {v0.16b}, [x5], #16
+
+        sxtl v3.8h, v0.8b
+        sxtl2 v4.8h, v0.16b
+
+        sxtl v16.4s, v3.4h
+        sxtl2 v17.4s, v3.8h
+        sxtl v18.4s, v4.4h
+        sxtl2 v19.4s, v4.8h
+
+        sub v16.4s, v16.4s, v1.4s
+        scvtf v16.4s, v16.4s
+        fmul v16.4s, v16.4s, v2.4s
+        sub v17.4s, v17.4s, v1.4s
+        scvtf v17.4s, v17.4s
+        fmul v17.4s, v17.4s, v2.4s
+        sub v18.4s, v18.4s, v1.4s
+        scvtf v18.4s, v18.4s
+        fmul v18.4s, v18.4s, v2.4s
+        sub v19.4s, v19.4s, v1.4s
+        scvtf v19.4s, v19.4s
+        fmul v19.4s, v19.4s, v2.4s
+        st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], #64
+        subs w7, w7, #16
+        bgt LoopCount
+ret
+
+#endif
\ No newline at end of file
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/IndirectGemmInt16to32_8x4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/IndirectGemmInt16to32_8x4.S
new file mode 100644
index 00000000..ba60f390
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/IndirectGemmInt16to32_8x4.S
@@ -0,0 +1,233 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset); +// x0: output, x1: input, x2: weight, x3: ksize, x4: ic8, x5: oc4, x6: offset +asm_function IndirectGemmInt16to32_8x4 + + .macro INIT_ZERO + dup v28.4s, wzr + mov v29.16b, v28.16b + mov v30.16b, v28.16b + mov v31.16b, v28.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + LoopOc: + mov x7, x3 + mov x8, x1 + + LoopKsize: + mov x9, x0 + INIT_ZERO + + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + // load weight + ld1 {v16.8h}, [x2], #16 + smull v24.4s, v16.4h, v0.h[0] + smull v25.4s, v16.4h, v1.h[0] + // load weight + ld1 {v17.8h}, [x2], #16 + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smull v26.4s, v16.4h, v2.h[0] + smull v27.4s, v16.4h, v3.h[0] + + subs x10, x4, #1 + beq LoopIcEnd + + LoopIc: + + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + // load weight + ld1 {v16.8h, v17.8h}, [x2], #32 + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + // load input + ld1 {v0.8h, 
v1.8h}, [x8], #32 + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v16.4h, v0.h[0] + smlal v25.4s, v16.4h, v1.h[0] + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + smlal v26.4s, v16.4h, v2.h[0] + smlal v27.4s, v16.4h, v3.h[0] + + subs x10, x10, #1 + bne LoopIc + + LoopIcEnd: + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + st1 {v24.4s}, [x9], x6 + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + st1 {v25.4s}, [x9], x6 + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + st1 {v26.4s}, [x9], x6 + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + st1 {v27.4s}, [x9], x6 + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + st1 {v28.4s}, [x9], x6 + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + st1 {v29.4s}, [x9], x6 + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + + st1 {v30.4s}, [x9], x6 + st1 {v31.4s}, [x9] + + subs x7, x7, #1 + add x0, x0, #16 + bne LoopKsize + + subs x5, x5, #1 + bne LoopOc + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulFp32.S new file mode 100644 index 00000000..bd427335 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulFp32.S @@ -0,0 +1,252 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_default_function MatVecMulFp32 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + + mov w14, #4 // sizeof(float) + mul w8, w14, w5 // rhs depthx1 block stride + mov w14, #4 + mul w13, w8, w14 // rhs depthx4 block stride + +Loop: + mov x15, x0 // reload a ptr + mov x7, x1 // reload b ptr + mov w9, w5 // reload depth + cmp w6, #4 + blt Loop1x1 + +Loop1x4: + dup v10.8h, wzr + dup v11.8h, wzr + dup v12.8h, wzr + dup v13.8h, wzr + dup v14.8h, wzr + + add x10, x7, x8 + add x11, x10, x8 + add x12, x11, x8 + +Depth8_1x4: + cmp w9, #8 + blt Depth4_1x4 + sub w9, w9, #8 + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v4.4s, v5.4s}, [x10], #32 + cmp w9, #8 + blt Depth8_1x4_Loop_End + +Depth8_1x4_Loop: + fmla v10.4s, v0.4s, v2.4s + fmla v10.4s, v1.4s, v3.4s + ld1 {v6.4s, v7.4s}, [x11], #32 + fmla v11.4s, v0.4s, v4.4s + fmla v11.4s, v1.4s, v5.4s + ld1 {v8.4s, v9.4s}, [x12], #32 + fmla v12.4s, v0.4s, v6.4s + fmla v12.4s, v1.4s, v7.4s + ld1 {v2.4s, v3.4s}, [x7], #32 + fmla v13.4s, v0.4s, v8.4s + fmla v13.4s, v1.4s, v9.4s + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v4.4s, v5.4s}, [x10], #32 + sub w9, w9, #8 + cmp w9, #8 + bge Depth8_1x4_Loop + +Depth8_1x4_Loop_End: + fmla v10.4s, v0.4s, v2.4s + fmla v10.4s, v1.4s, v3.4s + ld1 {v6.4s, v7.4s}, [x11], #32 + fmla v11.4s, v0.4s, v4.4s + fmla v11.4s, v1.4s, v5.4s + ld1 {v8.4s, v9.4s}, [x12], #32 + fmla v12.4s, v0.4s, v6.4s + fmla v12.4s, v1.4s, v7.4s + fmla v13.4s, v0.4s, v8.4s + fmla v13.4s, v1.4s, v9.4s + +Depth4_1x4: + cmp w9, #4 + blt Depth1_1x4 + sub w9, w9, #4 + ld1 {v0.4s}, [x15], #16 + ld1 {v1.4s}, [x7], #16 + ld1 {v2.4s}, [x10], #16 + cmp w9, #4 + blt Depth4_1x4_Loop_End + +Depth4_1x4_Loop: + fmla v10.4s, v1.4s, v0.4s + ld1 {v3.4s}, [x11], #16 + fmla v11.4s, v2.4s, v0.4s + ld1 {v4.4s}, [x12], #16 + fmla v12.4s, v3.4s, v0.4s + ld1 {v1.4s}, [x7], #16 + fmla v13.4s, v4.4s, v0.4s + ld1 {v0.4s}, [x15], #16 + ld1 {v2.4s}, [x10], #16 + sub w9, w9, #4 + cmp w9, #4 + bge Depth4_1x4_Loop + +Depth4_1x4_Loop_End: + fmla v10.4s, v1.4s, v0.4s + ld1 {v3.4s}, [x11], #16 + fmla v11.4s, v2.4s, v0.4s + ld1 {v4.4s}, [x12], #16 + fmla v12.4s, v3.4s, v0.4s + fmla v13.4s, v4.4s, v0.4s + +Depth1_1x4: + cmp w9, #0 + beq End1x4 + ld1 {v0.s}[0], [x15], #4 + ld1 {v1.s}[0], [x7], #4 + ld1 {v1.s}[1], [x10], #4 + ld1 {v1.s}[2], [x11], #4 + ld1 {v1.s}[3], [x12], #4 + + fmla v14.4s, v1.4s, v0.s[0] + sub w9, w9, #1 + cbz w9, End1x4 + b Depth1_1x4 + +End1x4: + faddp v15.4s, v10.4s, v11.4s + faddp v16.4s, v12.4s, v13.4s + faddp v17.4s, v15.4s, v16.4s + fadd v14.4s, v14.4s, v17.4s + + cbz x3, Act1x4 + ld1 {v15.4s}, [x3], #16 + fadd v14.4s, v14.4s, v15.4s // add bias + +Act1x4: + cmp w4, #3 + beq Relu6_1x4 + cmp w4, #1 + beq Relu1x4 + b Write1x4 + +Relu6_1x4: + movi v15.4s, #0x46, lsl #8 + 
fmin v14.4s, v14.4s, v15.4s + +Relu1x4: + dup v15.4s, wzr + fmax v14.4s, v14.4s, v15.4s + +Write1x4: + st1 {v14.4s}, [x2], #16 + sub w6, w6, #4 + cbz w6, End + add x1, x1, x13 + b Loop + + +Loop1x1: + dup v4.4s, wzr + dup v5.4s, wzr + +Depth8_1x1: + cmp w9, #8 + blt Depth4_1x1 + + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + + fmla v4.4s, v2.4s, v0.4s + fmla v4.4s, v3.4s, v1.4s + sub w9, w9, #8 + cbz w9, End1x1 + b Depth8_1x1 + +Depth4_1x1: + cmp w9, #4 + blt Depth1_1x1 + + ld1 {v0.4s}, [x15], #16 + ld1 {v1.4s}, [x7], #16 + + fmla v4.4s, v1.4s, v0.4s + sub w9, w9, #4 + cbz w9, End1x1 + b Depth8_1x1 + +Depth1_1x1: + ld1 {v0.s}[0], [x15], #4 + ld1 {v1.s}[0], [x7], #4 + + fmla v5.4s, v1.4s, v0.s[0] + sub w9, w9, #1 + cbz w9, End1x1 + b Depth1_1x1 + +End1x1: + faddp v6.4s, v4.4s, v4.4s + faddp v7.4s, v6.4s, v6.4s + fadd v7.4s, v7.4s, v5.4s + + cbz x3, Act1x1 + ld1 {v8.s}[0], [x3], #4 + fadd v7.4s, v7.4s, v8.4s // add bias + +Act1x1: + cmp w4, #3 + beq Relu6_1x1 + cmp w4, #1 + beq Relu1x1 + b Write1x1 + +Relu6_1x1: + movi v8.4s, #0x46, lsl #8 + fmin v7.4s, v7.4s, v8.4s + +Relu1x1: + dup v8.4s, wzr + fmax v7.4s, v7.4s, v8.4s + +Write1x1: + st1 {v7.s}[0], [x2], #4 + sub w6, w6, #1 + cbz w6, End + add x1, x1, x8 + b Loop + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulPackFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulPackFp32.S new file mode 100644 index 00000000..058a807c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatVecMulPackFp32.S @@ -0,0 +1,198 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_default_function MatVecMulPackFp32 + sub sp, sp, #16 + stp x29, x30, [sp] + + dup v1.2d, xzr + mov w7, #6 + dup v2.4s, w7 + scvtf v2.4s, v2.4s + subs w6, w6, #8 + blt Loop1xNStart + Loop1x8Start: + bl Compute1x8Unit + st1 {v24.4s, v25.4s}, [x2], #32 + subs w6, w6, #8 + bge Loop1x8Start + + Loop1xNStart: + add w6, w6, #8 + cbz w6, End + subs w6, w6, #4 + ble Loop1x4Start + bl Compute1x8Unit + st1 {v24.4s}, [x2], #16 + st1 {v25.s}[0], [x2], #4 + cmp w6, #1 + beq End + st1 {v25.s}[1], [x2], #4 + cmp w6, #2 + beq End + st1 {v25.s}[2], [x2] + b End + + Loop1x4Start: + add w6, w6, #4 + cbz w6, End + bl Compute1x4Unit + st1 {v24.s}[0], [x2], #4 + cmp w6, #1 + beq End + st1 {v24.s}[1], [x2], #4 + cmp w6, #2 + beq End + st1 {v24.s}[2], [x2], #4 + cmp w6, #3 + beq End + st1 {v24.s}[3], [x2], #4 + b End + + Compute1x8Unit: + mov x7, x0 // reload a-ptr + mov w8, w5 // reset depth + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + cbz x3, Compute1x8Enter + ld1 {v24.4s, v25.4s}, [x3], #32 + Compute1x8Enter: + subs w8, w8, #4 + blt Compute1x8Tail + Compute1x8: + ld1 {v0.4s}, [x7], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + fmla v24.4s, v16.4s, v0.s[0] + fmla v25.4s, v17.4s, v0.s[0] + fmla v26.4s, v18.4s, v0.s[1] + fmla v27.4s, v19.4s, v0.s[1] + fmla v28.4s, v20.4s, v0.s[2] + fmla v29.4s, v21.4s, v0.s[2] + fmla v30.4s, v22.4s, v0.s[3] + fmla v31.4s, v23.4s, v0.s[3] + subs w8, w8, #4 + bge Compute1x8 + Compute1x8Tail: + add w8, w8, #4 + cbz w8, Compute1x8UnionTail + Compute1x8DepthTail: + ld1 {v0.s}[0], [x7], #4 + ld1 {v16.4s, v17.4s}, [x1], #32 + fmla v24.4s, v16.4s, v0.s[0] + fmla v25.4s, v17.4s, v0.s[0] + subs w8, w8, #1 + bgt Compute1x8DepthTail + Compute1x8UnionTail: + fadd v24.4s, v24.4s, v26.4s + fadd v25.4s, v25.4s, v27.4s + fadd v28.4s, v28.4s, v30.4s + fadd v29.4s, v29.4s, v31.4s + fadd v24.4s, v24.4s, v28.4s + fadd v25.4s, v25.4s, v29.4s + Act1x8: + cmp x4, #3 + beq Relu61x8 + cmp x4, #1 + beq Relu1x8 + b Return1x8 + Relu61x8: + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmax v24.4s, v24.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + b Return1x8 + Relu1x8: + fmax v24.4s, v24.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + Return1x8: + ret + + Compute1x4Unit: + mov x7, x0 // reload a-ptr + mov w8, w5 // reset depth + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + cbz x3, Compute1x4Enter + ld1 {v24.4s}, [x3] + Compute1x4Enter: + subs w8, w8, #4 + blt Compute1x4Tail + Compute1x4: + ld1 {v0.4s}, [x7], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + fmla v24.4s, v16.4s, v0.s[0] + fmla v26.4s, v18.4s, v0.s[1] + fmla v28.4s, v20.4s, v0.s[2] + fmla v30.4s, v22.4s, v0.s[3] + subs w8, w8, #4 + bge Compute1x4 + Compute1x4Tail: + add w8, w8, #4 + cbz w8, Compute1x4UnionTail + Compute1x4DepthTail: + ld1 {v0.s}[0], [x7], #4 + ld1 {v16.4s}, [x1] + add x1, x1, #32 + fmla v24.4s, v16.4s, v0.s[0] + subs w8, w8, #1 + bgt Compute1x4DepthTail + Compute1x4UnionTail: + fadd v24.4s, v24.4s, v26.4s + fadd v28.4s, v28.4s, v30.4s + fadd v24.4s, v24.4s, v28.4s + Act1x4: + cmp x4, #3 + beq 
Relu61x4 + cmp x4, #1 + beq Relu1x4 + b Return1x4 + Relu61x4: + fmin v24.4s, v24.4s, v2.4s + fmax v24.4s, v24.4s, v1.4s + b Return1x8 + Relu1x4: + fmax v24.4s, v24.4s, v1.4s + Return1x4: + ret + + End: + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32.S new file mode 100644 index 00000000..3c648444 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32.S @@ -0,0 +1,787 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: row +// w7: col +// w17: stride +// w13: c8_nhwc_c4 + +asm_function MatmulFloatNeon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + ldr x9, [sp, #152] + ldr x14, [sp, #160] + + mov w19, #32 // sizeof(float) * 8 + mul w15, w5, w19 // block stride of lhs/rhs: sizeof(float) * 8 * depth + mov x19, #4 + ldr x17, [sp, #144] + cbz x14, NoWinoSteps + mul x8, x7, x17 + mov x11, #8 + mul x11, x11, x17 + mul x8, x8, x19 + mul x11, x11, x19 +NoWinoSteps: + mul x17, x17, x19 + +L1: + mov w10, w6 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x19, x2 // reload dst ptr + +L2: + mov x16, x1 // reload rhs ptr + mov w13, w5 // reload depth + dup v8.4s, wzr + dup v9.4s, wzr + dup v10.4s, wzr + dup v11.4s, wzr + dup v12.4s, wzr + dup v13.4s, wzr + dup v14.4s, wzr + dup v15.4s, wzr + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + +LoopStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 + ld1 {v3.4s, v4.4s}, [x16], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs w13, w13, #1 + beq LoopEnd + +Loop: + ld1 {v0.4s}, [x12], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + ld1 {v1.4s}, [x12], #16 + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + ld1 {v3.4s}, [x16], #16 + fmla v25.4s, v4.4s, v2.s[0] + fmla 
v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + ld1 {v4.4s}, [x16], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v2.4s}, [x12], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs w13, w13, #1 + bgt Loop + +LoopEnd: + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + +Bias: + cbz x3, Activation + ld1 {v0.4s}, [x3], #16 + ld1 {v1.4s}, [x3] + sub x3, x3, #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + +Activation: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + mov w13, #6 + dup v2.4s, w13 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + +Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + +Write: + cbnz x14, WriteWino + cbz x9, WriteC8 + cmp w7, #1 + beq Write1 + cmp w7, #2 + beq Write2 + cmp w7, #3 + beq Write3 + cmp w7, #4 + beq Write4 + cmp w7, #5 + beq Write5 + cmp w7, #6 + beq Write6 + cmp w7, #7 + beq Write7 + b Write8 + +Write1: + str s8, [x19] + cmp w10, #1 + beq WriteEnd + add x19, x19, x17 + str 
s10, [x19] + cmp w10, #2 + beq WriteEnd + add x19, x19, x17 + str s12, [x19] + cmp w10, #3 + beq WriteEnd + add x19, x19, x17 + str s14, [x19] + cmp w10, #4 + beq WriteEnd + add x19, x19, x17 + str s16, [x19] + cmp w10, #5 + beq WriteEnd + add x19, x19, x17 + str s18, [x19] + cmp w10, #6 + beq WriteEnd + add x19, x19, x17 + str s20, [x19] + cmp w10, #7 + beq WriteEnd + add x19, x19, x17 + str s22, [x19] + cmp w10, #8 + beq WriteEnd + add x19, x19, x17 + str s24, [x19] + cmp w10, #9 + beq WriteEnd + add x19, x19, x17 + str s26, [x19] + cmp w10, #10 + beq WriteEnd + add x19, x19, x17 + str s28, [x19] + cmp w10, #11 + beq WriteEnd + add x19, x19, x17 + str s30, [x19] + add x19, x19, x17 + b WriteEnd +Write2: + dup s9, v8.s[1] + stp s8, s9, [x19] + cmp w10, #1 + beq WriteEnd + add x19, x19, x17 + dup s11, v10.s[1] + stp s10, s11, [x19] + cmp w10, #2 + beq WriteEnd + add x19, x19, x17 + dup s13, v12.s[1] + stp s12, s13, [x19] + cmp w10, #3 + beq WriteEnd + add x19, x19, x17 + dup s15, v14.s[1] + stp s14, s15, [x19] + cmp w10, #4 + beq WriteEnd + add x19, x19, x17 + dup s17, v16.s[1] + stp s16, s17, [x19] + cmp w10, #5 + beq WriteEnd + add x19, x19, x17 + dup s19, v18.s[1] + stp s18, s19, [x19] + cmp w10, #6 + beq WriteEnd + add x19, x19, x17 + dup s21, v20.s[1] + stp s20, s21, [x19] + cmp w10, #7 + beq WriteEnd + add x19, x19, x17 + dup s23, v22.s[1] + stp s22, s23, [x19] + cmp w10, #8 + beq WriteEnd + add x19, x19, x17 + dup s25, v24.s[1] + stp s24, s25, [x19] + cmp w10, #9 + beq WriteEnd + add x19, x19, x17 + dup s27, v26.s[1] + stp s26, s27, [x19] + cmp w10, #10 + beq WriteEnd + add x19, x19, x17 + dup s29, v28.s[1] + stp s28, s29, [x19] + cmp w10, #11 + beq WriteEnd + add x19, x19, x17 + dup s31, v30.s[1] + stp s30, s31, [x19] + add x19, x19, x17 + b WriteEnd +Write3: + add x13, x19, #8 + dup s9, v8.s[1] + stp s8, s9, [x19] + add x19, x19, x17 + st1 {v8.s}[2], [x13], x17 + cmp w10, #1 + beq WriteEnd + dup s11, v10.s[1] + stp s10, s11, [x19] + add x19, x19, x17 + st1 {v10.s}[2], [x13], x17 + cmp w10, #2 + beq WriteEnd + dup s13, v12.s[1] + stp s12, s13, [x19] + add x19, x19, x17 + st1 {v12.s}[2], [x13], x17 + cmp w10, #3 + beq WriteEnd + dup s15, v14.s[1] + stp s14, s15, [x19] + add x19, x19, x17 + st1 {v14.s}[2], [x13], x17 + cmp w10, #4 + beq WriteEnd + dup s17, v16.s[1] + stp s16, s17, [x19] + add x19, x19, x17 + st1 {v16.s}[2], [x13], x17 + cmp w10, #5 + beq WriteEnd + dup s19, v18.s[1] + stp s18, s19, [x19] + add x19, x19, x17 + st1 {v18.s}[2], [x13], x17 + cmp w10, #6 + beq WriteEnd + dup s21, v20.s[1] + stp s20, s21, [x19] + add x19, x19, x17 + st1 {v20.s}[2], [x13], x17 + cmp w10, #7 + beq WriteEnd + dup s23, v22.s[1] + stp s22, s23, [x19] + add x19, x19, x17 + st1 {v22.s}[2], [x13], x17 + cmp w10, #8 + beq WriteEnd + dup s25, v24.s[1] + stp s24, s25, [x19] + add x19, x19, x17 + st1 {v24.s}[2], [x13], x17 + cmp w10, #9 + beq WriteEnd + dup s27, v26.s[1] + stp s26, s27, [x19] + add x19, x19, x17 + st1 {v26.s}[2], [x13], x17 + cmp w10, #10 + beq WriteEnd + dup s29, v28.s[1] + stp s28, s29, [x19] + add x19, x19, x17 + st1 {v28.s}[2], [x13], x17 + cmp w10, #11 + beq WriteEnd + dup s31, v30.s[1] + stp s30, s31, [x19] + add x19, x19, x17 + st1 {v30.s}[2], [x13] + b WriteEnd +Write4: + st1 {v8.4s}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s}, [x19], x17 + cmp 
w10, #6 + beq WriteEnd + st1 {v20.4s}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s}, [x19], x17 + b WriteEnd +Write5: + add x13, x19, #16 + st1 {v8.4s}, [x19], x17 + str s9, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x19], x17 + str s11, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x19], x17 + str s13, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x19], x17 + str s15, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v16.4s}, [x19], x17 + str s17, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x19], x17 + str s19, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x19], x17 + str s21, [x13] + cmp w10, #7 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x19], x17 + str s23, [x13] + cmp w10, #8 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x19], x17 + str s25, [x13] + cmp w10, #9 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x19], x17 + str s27, [x13] + cmp w10, #10 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x19], x17 + str s29, [x13] + cmp w10, #11 + beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x19], x17 + str s31, [x13] + b WriteEnd +Write6: + add x13, x19, #16 + st1 {v8.4s}, [x19], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x19], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x19], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x19], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v16.4s}, [x19], x17 + dup s16, v17.s[1] + stp s17, s16, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x19], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x19], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + cmp w10, #7 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x19], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + cmp w10, #8 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x19], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + cmp w10, #9 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x19], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + cmp w10, #10 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x19], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + cmp w10, #11 + beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x19], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + b WriteEnd +Write7: + add x13, x19, #16 + add x16, x19, #24 + st1 {v8.4s}, [x19], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + add x13, x13, x17 + st1 {v9.s}[2], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s}, [x19], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + add x13, x13, x17 + st1 {v11.s}[2], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s}, [x19], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + add x13, x13, x17 + st1 {v13.s}[2], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s}, [x19], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + add x13, x13, x17 + st1 {v15.s}[2], [x16], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s}, [x19], x17 + dup s16, v17.s[1] + stp s17, s16, 
[x13] + add x13, x13, x17 + st1 {v17.s}[2], [x16], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s}, [x19], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + add x13, x13, x17 + st1 {v19.s}[2], [x16], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s}, [x19], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + add x13, x13, x17 + st1 {v21.s}[2], [x16], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s}, [x19], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + add x13, x13, x17 + st1 {v23.s}[2], [x16], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s}, [x19], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + add x13, x13, x17 + st1 {v25.s}[2], [x16], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s}, [x19], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + add x13, x13, x17 + st1 {v27.s}[2], [x16], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s}, [x19], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + add x13, x13, x17 + st1 {v29.s}[2], [x16], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s}, [x19], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + add x13, x13, x17 + st1 {v31.s}[2], [x16], x17 + b WriteEnd +WriteC8: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + b WriteEnd +WriteWino: + st1 {v8.4s, v9.4s}, [x19], x8 + st1 {v10.4s, v11.4s}, [x19], x8 + st1 {v12.4s, v13.4s}, [x19], x8 + st1 {v14.4s, v15.4s}, [x19], x8 + st1 {v16.4s, v17.4s}, [x19], x8 + st1 {v18.4s, v19.4s}, [x19], x8 + st1 {v20.4s, v21.4s}, [x19], x8 + st1 {v22.4s, v23.4s}, [x19], x8 + st1 {v24.4s, v25.4s}, [x19], x8 + st1 {v26.4s, v27.4s}, [x19], x8 + st1 {v28.4s, v29.4s}, [x19], x8 + st1 {v30.4s, v31.4s}, [x19], x8 + b WriteEnd +Write8: + st1 {v8.4s, v9.4s}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x19], x17 + +WriteEnd: + subs w10, w10, #12 // lhs row - 12 + bgt L2 + +End2: + subs w7, w7, #8 // rhs col - 8 + add x1, x1, x15 // rhs ptr + stride + cbz x3, NoBiasStep + add x3, x3, #32 // bias ptr + stride +NoBiasStep: + cbnz x14, WinoDstStep + cbz x9, NoDstStep + add x2, x2, #32 // dst ptr + stride + b NoDstStep +WinoDstStep: + add x2, x2, x11 +NoDstStep: + bgt L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32Opt.S new file mode 100644 index 00000000..abaf79e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32Opt.S @@ -0,0 +1,1669 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64Opt + sub sp, sp, #160 + + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #48 // 12 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4 , block stride 192, otherwise 12 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #192 // block stride + + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne RowStart + mov x20, x2 +RowStart: + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + cmp x9, #3 + beq C4ReloadDst + mov x11, x2 + b NoReloadDst + C4ReloadDst: + mov x11, x20 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf + + LoopDepthStart: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + fmul v25.4s, v4.4s, v2.s[0] + fmul v27.4s, v4.4s, v2.s[1] + fmul v29.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + beq Bias + + LoopDepth: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm 
pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + + LoopDepthStartHalf: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, 
v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + beq BiasHalf + + LoopDepthHalf: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf + + BiasHalf: + cbz x3, ActivationHalf + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + fadd v24.4s, v24.4s, v0.4s + fadd v26.4s, v26.4s, v0.4s + fadd v28.4s, v28.4s, v0.4s + fadd v30.4s, v30.4s, v0.4s + + ActivationHalf: + cmp x4, #3 + beq Relu6Half + cmp x4, #1 + beq ReluHalf + b Write + + Relu6Half: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + ReluHalf: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + b Write + + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + cmp x9, #3 + beq C4ReloadDst8 + mov x11, x2 + b NoReloadDst8 + C4ReloadDst8: + mov x11, x20 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf8 + + LoopDepthStart8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, 
v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + + LoopDepthStartHalf8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + beq BiasHalf8 + + LoopDepthHalf8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf8 + + BiasHalf8: + cbz x3, ActivationHalf8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + + ActivationHalf8: + cmp x4, #3 + beq Relu6Half8 + cmp x4, #1 + beq ReluHalf8 + b Write + + Relu6Half8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, 
v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + ReluHalf8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + cmp x9, #3 + beq C4ReloadDst4 + mov x11, x2 + b NoReloadDst4 + C4ReloadDst4: + mov x11, x20 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf4 + + LoopDepthStart4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b Write + + LoopDepthStartHalf4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + beq BiasHalf4 + + LoopDepthHalf4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf4 + + BiasHalf4: + cbz x3, ActivationHalf4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + + ActivationHalf4: + cmp x4, #3 + beq Relu6Half4 + cmp x4, #1 + beq ReluHalf4 + b Write + + Relu6Half4: + mov w19, #6 + dup v2.4s, 
w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + ReluHalf4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s30, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + st1 {v30.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 
{v24.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str s25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str s27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str s29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str s31, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + st1 {v25.s}[2], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + st1 {v27.s}[2], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + st1 
{v29.s}[2], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x20] + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + st1 {v24.4s, v25.4s}, [x11], x15 + st1 {v26.4s, v27.4s}, [x11], x15 + st1 {v28.4s, v29.4s}, [x11], x15 + st1 {v30.4s, v31.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + // add x20, x11, x8 + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + cmp x6, #8 + beq WriteEnd + str s24, [x11], #4 + cmp x6, #9 + beq WriteEnd + str s26, [x11], #4 + cmp x6, #10 + beq WriteEnd + str s28, [x11], #4 + cmp x6, #11 + beq WriteEnd + str s30, [x11], #4 + b WriteEnd + C4Write2: + // add x20, x11, x8 + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], #8 + b WriteEnd + C4Write3: + // add x20, x11, x8 + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 
{v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + add x11, x11, #12 + st1 {v22.s}[2], [x19] + add x19, x19, #12 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11] + add x11, x11, #12 + st1 {v24.s}[2], [x19] + add x19, x19, #12 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11] + add x11, x11, #12 + st1 {v26.s}[2], [x19] + add x19, x19, #12 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11] + add x11, x11, #12 + st1 {v28.s}[2], [x19] + add x19, x19, #12 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11] + add x11, x11, #12 + st1 {v30.s}[2], [x19] + add x19, x19, #12 + b WriteEnd + + C4Write4: + add x20, x11, x8 + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + str s25, [x19], #4 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + str s27, [x19], #4 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + str s29, [x19], #4 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + str s31, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], #8 + cmp x6, #10 + 
beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + st1 {v31.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], x15 + st1 {v25.s}[2], [x16], x15 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], x15 + st1 {v27.s}[2], [x16], x15 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], x15 + st1 {v29.s}[2], [x16], x15 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x16] + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq WriteEnd + + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + cmp x6, #8 + beq WriteEnd + + st1 {v24.4s}, [x11], #16 + st1 {v25.4s}, [x19], #16 + cmp x6, #9 + beq WriteEnd + + st1 {v26.4s}, [x11], #16 + st1 {v27.4s}, [x19], #16 + cmp x6, #10 + beq WriteEnd + + st1 {v28.4s}, [x11], #16 + st1 {v29.4s}, [x19], #16 + cmp x6, #11 + beq WriteEnd + + st1 {v30.4s}, [x11] + st1 {v31.4s}, [x19] + b WriteEnd + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow12.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow12.S new file mode 100644 index 00000000..21dc91ca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow12.S @@ -0,0 +1,1229 @@ +/** + * Copyright 2021 Huawei Technologies Co., 
Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64OptRow12(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow12 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #48 // 12 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4, block stride 192, otherwise 12 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #192 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col * sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow + mov x20, x2 +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + cmp x9, #3 + beq C4ReloadDst + mov x11, x2 + b NoReloadDst + C4ReloadDst: + mov x11, x20 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf + + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + fmul v25.4s, v4.4s, v2.s[0] + fmul v27.4s, v4.4s, v2.s[1] + fmul v29.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + beq Bias + + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla 
v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + + LoopDepthStartHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + beq BiasHalf + + LoopDepthHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + 
fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf + + BiasHalf: + cbz x3, ActivationHalf + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + fadd v24.4s, v24.4s, v0.4s + fadd v26.4s, v26.4s, v0.4s + fadd v28.4s, v28.4s, v0.4s + fadd v30.4s, v30.4s, v0.4s + + ActivationHalf: + cmp x4, #3 + beq Relu6Half + cmp x4, #1 + beq ReluHalf + b Write + + Relu6Half: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + ReluHalf: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s30, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + 
st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + st1 {v30.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str s25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str s27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str s29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str s31, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 
{v27.2s}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + st1 {v25.s}[2], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + st1 {v27.s}[2], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + st1 {v29.s}[2], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x20] + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + st1 {v24.4s, v25.4s}, [x11], x15 + st1 {v26.4s, v27.4s}, [x11], x15 + st1 {v28.4s, v29.4s}, [x11], x15 + st1 {v30.4s, v31.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq 
WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + cmp x6, #8 + beq WriteEnd + str s24, [x11], #4 + cmp x6, #9 + beq WriteEnd + str s26, [x11], #4 + cmp x6, #10 + beq WriteEnd + str s28, [x11], #4 + cmp x6, #11 + beq WriteEnd + str s30, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + add x11, x11, #12 + st1 {v22.s}[2], [x19] + add x19, x19, #12 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11] + add x11, x11, #12 + st1 {v24.s}[2], [x19] + add x19, x19, #12 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11] + add x11, x11, #12 + st1 {v26.s}[2], [x19] + add x19, x19, #12 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11] + add x11, x11, #12 + st1 {v28.s}[2], [x19] + add x19, x19, #12 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11] + add x11, x11, #12 + st1 {v30.s}[2], [x19] + add x19, x19, #12 + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, 
[x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + str s25, [x19], #4 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + str s27, [x19], #4 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + str s29, [x19], #4 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + str s31, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + st1 {v31.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], x15 + st1 {v25.s}[2], [x16], x15 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], x15 + st1 {v27.s}[2], [x16], x15 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], x15 + st1 {v29.s}[2], [x16], x15 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x16] + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + cmp x6, #8 + beq 
WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.4s}, [x19], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.4s}, [x19], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.4s}, [x19], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.4s}, [x19] + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRow + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow4.S new file mode 100644 index 00000000..9798eabd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow4.S @@ -0,0 +1,597 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64OptRow4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow4 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #16 // 4 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4 , block stride 64, otherwise 4 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #64 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow4 + mov x20, x2 +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + cmp x9, #3 + beq C4ReloadDst4 + mov x11, x2 + b NoReloadDst4 + C4ReloadDst4: + mov x11, x20 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf4 + + LoopDepthStart4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + 
fmax v15.4s, v15.4s, v3.4s + b Write + + LoopDepthStartHalf4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + beq BiasHalf4 + + LoopDepthHalf4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf4 + + BiasHalf4: + cbz x3, ActivationHalf4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + + ActivationHalf4: + cmp x4, #3 + beq Relu6Half4 + cmp x4, #1 + beq ReluHalf4 + b Write + + Relu6Half4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + ReluHalf4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq 
WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + st1 {v14.s}[2], [x19] + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 
{v11.4s}, [x19], #16
+        cmp x6, #2
+        beq WriteEnd
+        st1 {v12.4s}, [x11], #16
+        st1 {v13.4s}, [x19], #16
+        cmp x6, #3
+        beq WriteEnd
+        st1 {v14.4s}, [x11], #16
+        st1 {v15.4s}, [x19], #16
+    WriteEnd:
+        subs x13, x13, #8 // rhs col - 8
+        bgt LoopCol4
+
+LoopColEnd:
+    add x0, x0, x17
+    cbz x9, C8DstStep
+    cmp x9, #3
+    beq C4DstStep
+    mov x21, #4
+    mul x21, x21, x7
+    sub x11, x11, x21
+    mov x2, x11
+    b NoDstStep
+    C4DstStep:
+        add x2, x2, x18
+        b NoDstStep
+    C8DstStep:
+        add x2, x2, #384
+        mov x11, x2
+    NoDstStep:
+        subs x6, x6, #12
+        bgt LoopRow4
+
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ret
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow8.S
new file mode 100644
index 00000000..998b1e93
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulFp32OptRow8.S
@@ -0,0 +1,911 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void MatmulFloatNeon64OptRow8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+// int row, int col, size_t stride, size_t writeMode)
+// x0: a
+// x1: b
+// x2: c
+// x3: bias
+// x4: act_type
+// x5: depth
+// x6: row
+// x7: col
+// x8: stride
+// x9: writeMode
+
+asm_function MatmulFloatNeon64OptRow8
+    sub sp, sp, #160
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp]
+    add x9, sp, #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
+    stp x19, x20, [sp, #128]
+    stp x21, x22, [sp, #144]
+
+    ldr x8, [sp, #160]
+    ldr x9, [sp, #168]
+
+    mov x21, #48 // sizeof(float) * 12
+    mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth
+    cmp x9, #3 // c4
+    beq C4Stride
+    cbnz x9, NoC8Steps
+    mov x11, x2
+    mov x21, #32
+    mul x16, x6, x21 // row * 8 * sizeof(float)
+    b NoC8Steps
+C4Stride:
+    mov x18, #32 // 8 * sizeof(float)
+    mov x22, #4
+    mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row
+    mul x8, x8, x22 // col stride
+    // col >= 4, block stride 128, otherwise 8 * 4 * col
+    cmp x7, #4
+    bge C4StrideCommon
+    mul x18, x18, x7 // block stride
+    b LoopRowStart
+C4StrideCommon:
+    mov x18, #128 // block stride
+    b LoopRowStart
+
+NoC8Steps:
+    cmp x9, #2
+    bne NoWinoSteps
+    mov x21, #4
+    mul x15, x7, x8
+    mul x15, x15, x21 // kernel_size * col * sizeof(float)
+    mov x21, #32
+    mul x16, x8, x21 // kernel_size * 8 * sizeof(float)
+NoWinoSteps:
+    mov x21, #4
+    mul x8, x8, x21
+
+LoopRowStart:
+    cmp x9, #3
+    bne LoopRow8
+    mov x20, x2
+LoopRow8:
+    mov x14, x1 // reload rhs ptr
+    mov x13, x7 // reload rhs col
+    mov x12, x3 // reload bias
+
+    LoopCol8:
+        cbz x9, NoReloadDst8
+        cmp x9, #3
+        beq C4ReloadDst8
+        mov x11, x2
+        b NoReloadDst8
+    C4ReloadDst8:
+        mov x11, x20
+    NoReloadDst8:
+        mov x10, x0 // reload lhs ptr
+        mov x19, x5 //
reload depth + + cmp x13, #4 + ble LoopDepthStartHalf8 + + LoopDepthStart8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + + LoopDepthStartHalf8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 // weight packed 8, only hold place + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + beq BiasHalf8 + + LoopDepthHalf8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 // only hold place + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, 
v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf8 + + BiasHalf8: + cbz x3, ActivationHalf8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + + ActivationHalf8: + cmp x4, #3 + beq Relu6Half8 + cmp x4, #1 + beq ReluHalf8 + b Write + + Relu6Half8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + ReluHalf8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + 
st1 {v22.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + 
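+    // C4 (NC4HW4) writeback follows: x13 holds the columns remaining in this
+    // block, and C4Write1..C4Write7 store a partial final block of 1-7 floats
+    // per row. C4Write8 stores two full 4-float groups per row, the second
+    // group going to the next C4 plane at [x11 + x8] (x8 was rescaled on the
+    // C4 path above, where stride == row).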
WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + st1 {v22.s}[2], [x19] + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + 
        cmp x6, #7
+        beq WriteEnd
+        st1 {v22.4s}, [x11], #16
+        st1 {v23.2s}, [x19], #8
+        b WriteEnd
+    C4Write7:
+        add x19, x11, x8
+        add x16, x19, #8
+        mov x15, #12
+        st1 {v8.4s}, [x11], #16
+        st1 {v9.2s}, [x19], x15
+        st1 {v9.s}[2], [x16], x15
+        cmp x6, #1
+        beq WriteEnd
+        st1 {v10.4s}, [x11], #16
+        st1 {v11.2s}, [x19], x15
+        st1 {v11.s}[2], [x16], x15
+        cmp x6, #2
+        beq WriteEnd
+        st1 {v12.4s}, [x11], #16
+        st1 {v13.2s}, [x19], x15
+        st1 {v13.s}[2], [x16], x15
+        cmp x6, #3
+        beq WriteEnd
+        st1 {v14.4s}, [x11], #16
+        st1 {v15.2s}, [x19], x15
+        st1 {v15.s}[2], [x16], x15
+        cmp x6, #4
+        beq WriteEnd
+        st1 {v16.4s}, [x11], #16
+        st1 {v17.2s}, [x19], x15
+        st1 {v17.s}[2], [x16], x15
+        cmp x6, #5
+        beq WriteEnd
+        st1 {v18.4s}, [x11], #16
+        st1 {v19.2s}, [x19], x15
+        st1 {v19.s}[2], [x16], x15
+        cmp x6, #6
+        beq WriteEnd
+        st1 {v20.4s}, [x11], #16
+        st1 {v21.2s}, [x19], x15
+        st1 {v21.s}[2], [x16], x15
+        cmp x6, #7
+        beq WriteEnd
+        st1 {v22.4s}, [x11], #16
+        st1 {v23.2s}, [x19], x15
+        st1 {v23.s}[2], [x16], x15
+        b WriteEnd
+    C4Write8:
+        add x19, x11, x8
+        add x20, x19, x8
+        st1 {v8.4s}, [x11], #16
+        st1 {v9.4s}, [x19], #16
+        cmp x6, #1
+        beq WriteEnd
+        st1 {v10.4s}, [x11], #16
+        st1 {v11.4s}, [x19], #16
+        cmp x6, #2
+        beq WriteEnd
+        st1 {v12.4s}, [x11], #16
+        st1 {v13.4s}, [x19], #16
+        cmp x6, #3
+        beq WriteEnd
+        st1 {v14.4s}, [x11], #16
+        st1 {v15.4s}, [x19], #16
+        cmp x6, #4
+        beq WriteEnd
+        st1 {v16.4s}, [x11], #16
+        st1 {v17.4s}, [x19], #16
+        cmp x6, #5
+        beq WriteEnd
+        st1 {v18.4s}, [x11], #16
+        st1 {v19.4s}, [x19], #16
+        cmp x6, #6
+        beq WriteEnd
+        st1 {v20.4s}, [x11], #16
+        st1 {v21.4s}, [x19], #16
+        cmp x6, #7
+        beq WriteEnd
+        st1 {v22.4s}, [x11], #16
+        st1 {v23.4s}, [x19], #16
+    WriteEnd:
+        subs x13, x13, #8 // rhs col - 8
+        bgt LoopCol8
+
+LoopColEnd:
+    add x0, x0, x17
+    cbz x9, C8DstStep
+    cmp x9, #3
+    beq C4DstStep
+    mov x21, #4
+    mul x21, x21, x7
+    sub x11, x11, x21
+    mov x2, x11
+    b NoDstStep
+    C4DstStep:
+        add x2, x2, x18
+        b NoDstStep
+    C8DstStep:
+        add x2, x2, #384
+        mov x11, x2
+    NoDstStep:
+        subs x6, x6, #12
+        bgt LoopRow8
+
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ret
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8.S
new file mode 100644
index 00000000..10a74163
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8.S
@@ -0,0 +1,420 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
+// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
+// int32_t *right_shift, int row, int col, int stride, int filter_peroc)
+
+// x0: a(left matrix ptr)
+// x1: b(right matrix ptr)
+// x2: out ptr
+// w3: row4
+// w4: col4
+// w5: deep16
+// x6: a_sums
+// x7: bias
+// w8: act_min
+// w9: act_max
+// w10: out_zp
+// x11: multiplier
+// x12: left_shift
+// x13: right_shift
+// w14: row
+// w15: col
+// w24: stride
+// w27: filter_peroc
+
+asm_function MatmulInt8Neon64
+    sub sp, sp, #208
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp]
+    add x9, sp, #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
+    stp x19, x20, [sp, #128]
+    stp x21, x22, [sp, #144]
+    stp x23, x24, [sp, #160]
+    stp x25, x26, [sp, #176]
+    stp x27, x28, [sp, #192]
+
+    ldr w8, [sp, #208]
+    ldr w9, [sp, #216]
+    ldr w10, [sp, #224]
+    ldr x11, [sp, #232]
+    ldr x12, [sp, #240]
+    ldr x13, [sp, #248]
+    ldr w14, [sp, #256]
+    ldr w15, [sp, #264]
+    ldr w24, [sp, #272]
+    ldr w27, [sp, #280]
+
+    mov w17, #4 // sizeof(int8)*4
+    mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
+    mov w17, #1
+    mov x25, x2
+L1:
+    cmp w4, #0 // if at the end of col4
+    beq End1
+
+    mov w16, w3 // reset a row4 counter
+    mov w23, w14 // reset a row counter
+    mov x17, x0 // reload a ptr
+    mov x22, x6 // reload a_sums ptr
+L2:
+    cmp w16, #0
+    beq End2
+
+    mov x28, x1 // reload b ptr
+    mov x19, x7 // reload bias ptr
+    mov w20, w5 // reload depth
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+L3:
+    cmp w20, #0
+    beq End3
+
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b}, [x17], #16
+    ld1 {v2.16b}, [x17], #16
+    ld1 {v3.16b}, [x17], #16
+    ld1 {v4.16b}, [x28], #16
+    ld1 {v5.16b}, [x28], #16
+    ld1 {v6.16b}, [x28], #16
+    ld1 {v7.16b}, [x28], #16
+
+    smull v8.8h, v4.8b, v0.8b
+    smull v9.8h, v5.8b, v0.8b
+    smull v10.8h, v6.8b, v0.8b
+    smull v11.8h, v7.8b, v0.8b
+    smull v12.8h, v4.8b, v1.8b
+    smull v13.8h, v5.8b, v1.8b
+    smull v14.8h, v6.8b, v1.8b
+    smull v15.8h, v7.8b, v1.8b
+
+    smlal2 v8.8h, v4.16b, v0.16b
+    smlal2 v9.8h, v5.16b, v0.16b
+    smlal2 v10.8h, v6.16b, v0.16b
+    smlal2 v11.8h, v7.16b, v0.16b
+    smlal2 v12.8h, v4.16b, v1.16b
+    smlal2 v13.8h, v5.16b, v1.16b
+    smlal2 v14.8h, v6.16b, v1.16b
+    smlal2 v15.8h, v7.16b, v1.16b
+
+    sadalp v16.4s, v8.8h
+    sadalp v17.4s, v9.8h
+    sadalp v18.4s, v10.8h
+    sadalp v19.4s, v11.8h
+    sadalp v20.4s, v12.8h
+    sadalp v21.4s, v13.8h
+    sadalp v22.4s, v14.8h
+    sadalp v23.4s, v15.8h
+
+    smull v8.8h, v4.8b, v2.8b
+    smull v9.8h, v5.8b, v2.8b
+    smull v10.8h, v6.8b, v2.8b
+    smull v11.8h, v7.8b, v2.8b
+    smull v12.8h, v4.8b, v3.8b
+    smull v13.8h, v5.8b, v3.8b
+    smull v14.8h, v6.8b, v3.8b
+    smull v15.8h, v7.8b, v3.8b
+
+    smlal2 v8.8h, v4.16b, v2.16b
+    smlal2 v9.8h, v5.16b, v2.16b
+    smlal2 v10.8h, v6.16b, v2.16b
+    smlal2 v11.8h, v7.16b, v2.16b
+    smlal2 v12.8h, v4.16b, v3.16b
+    smlal2 v13.8h, v5.16b, v3.16b
+    smlal2 v14.8h, v6.16b, v3.16b
+    smlal2 v15.8h, v7.16b, v3.16b
+
+    sadalp v24.4s, v8.8h
+    sadalp v25.4s, v9.8h
+    sadalp v26.4s, v10.8h
+    sadalp v27.4s, v11.8h
+    sadalp v28.4s, v12.8h
+    sadalp v29.4s, v13.8h
+    sadalp v30.4s, v14.8h
+    sadalp v31.4s, v15.8h
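+    // Each pass of L3 consumes 16 depth values for a 4x4 output tile:
+    // smull multiplies the low 8 int8 lanes of the A and B vectors into
+    // int16, smlal2 folds in the high 8 lanes, and sadalp pairwise-adds
+    // the int16 products into the int32 accumulators v16..v31. Per output
+    // element this is effectively:
+    //   acc[r][c] += (int32)a[r][k] * (int32)b[c][k], summed over k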
+    subs w20, w20, #16 // depth - 16
+    b L3
+
+End3:
+    addp v16.4s, v16.4s, v17.4s
+    addp v18.4s, v18.4s, v19.4s
+    addp v20.4s, v20.4s, v21.4s
+    addp v22.4s, v22.4s, v23.4s
+    addp v24.4s, v24.4s, v25.4s
+    addp v26.4s, v26.4s, v27.4s
+    addp v28.4s, v28.4s, v29.4s
+    addp v30.4s, v30.4s, v31.4s
+
+    addp v16.4s, v16.4s, v18.4s
+    addp v17.4s, v20.4s, v22.4s
+    addp v18.4s, v24.4s, v26.4s
+    addp v19.4s, v28.4s, v30.4s
+
+    // Add (Bias+Depth*Za*Zb-Za*Bsums)
+    ld1 {v15.4s}, [x19], #16
+    add v16.4s, v16.4s, v15.4s
+    add v17.4s, v17.4s, v15.4s
+    add v18.4s, v18.4s, v15.4s
+    add v19.4s, v19.4s, v15.4s
+
+    cmp w27, #0
+    beq PerTLoad
+PerCLoad:
+    ld1 {v20.4s}, [x6], #16
+    ld1 {v21.4s}, [x6], #16
+    ld1 {v22.4s}, [x6], #16
+    ld1 {v23.4s}, [x6], #16
+
+    ld1 {v13.4s}, [x12]
+    ld1 {v12.4s}, [x11]
+    ld1 {v11.4s}, [x13]
+    b Apply
+
+PerTLoad:
+    ld1 {v14.4s}, [x22], #16
+    dup v20.4s, v14.s[0]
+    dup v21.4s, v14.s[1]
+    dup v22.4s, v14.s[2]
+    dup v23.4s, v14.s[3]
+
+    ld1 {v14.s}[0], [x12]
+    dup v13.4s, v14.s[0]
+    ld1 {v14.s}[0], [x11]
+    dup v12.4s, v14.s[0]
+    ld1 {v14.s}[0], [x13]
+    dup v11.4s, v14.s[0]
+    b Apply
+
+Apply:
+    // Subtract (Asums*Zb)
+    sub v16.4s, v16.4s, v20.4s
+    sub v17.4s, v17.4s, v21.4s
+    sub v18.4s, v18.4s, v22.4s
+    sub v19.4s, v19.4s, v23.4s
+
+    // Apply left shift
+    sqshl v16.4s, v16.4s, v13.4s
+    sqshl v17.4s, v17.4s, v13.4s
+    sqshl v18.4s, v18.4s, v13.4s
+    sqshl v19.4s, v19.4s, v13.4s
+
+    // Apply the fixed-point part of the multiplier.
+    sqrdmulh v16.4s, v16.4s, v12.4s
+    sqrdmulh v17.4s, v17.4s, v12.4s
+    sqrdmulh v18.4s, v18.4s, v12.4s
+    sqrdmulh v19.4s, v19.4s, v12.4s
+
+    // Apply right shift
+    and v20.16b, v11.16b, v16.16b
+    sshr v20.4s, v20.4s, #31
+    sqadd v16.4s, v16.4s, v20.4s
+    srshl v16.4s, v16.4s, v11.4s
+    and v21.16b, v11.16b, v17.16b
+    sshr v21.4s, v21.4s, #31
+    sqadd v17.4s, v17.4s, v21.4s
+    srshl v17.4s, v17.4s, v11.4s
+    and v22.16b, v11.16b, v18.16b
+    sshr v22.4s, v22.4s, #31
+    sqadd v18.4s, v18.4s, v22.4s
+    srshl v18.4s, v18.4s, v11.4s
+    and v23.16b, v11.16b, v19.16b
+    sshr v23.4s, v23.4s, #31
+    sqadd v19.4s, v19.4s, v23.4s
+    srshl v19.4s, v19.4s, v11.4s
+
+    // Add the destination zero point
+    dup v10.4s, w10
+    add v16.4s, v16.4s, v10.4s
+    add v17.4s, v17.4s, v10.4s
+    add v18.4s, v18.4s, v10.4s
+    add v19.4s, v19.4s, v10.4s
+
+    // Apply the act_min bound
+    dup v9.4s, w8
+    smax v16.4s, v16.4s, v9.4s
+    smax v17.4s, v17.4s, v9.4s
+    smax v18.4s, v18.4s, v9.4s
+    smax v19.4s, v19.4s, v9.4s
+
+    // Apply the act_max bound
+    dup v8.4s, w9
+    smin v16.4s, v16.4s, v8.4s
+    smin v17.4s, v17.4s, v8.4s
+    smin v18.4s, v18.4s, v8.4s
+    smin v19.4s, v19.4s, v8.4s
+
+    // int32 -> int16
+    sqxtn v13.4h, v16.4s
+    sqxtn2 v13.8h, v17.4s
+    sqxtn v14.4h, v18.4s
+    sqxtn2 v14.8h, v19.4s
+
+    // int16 -> int8
+    sqxtn v15.8b, v13.8h
+    sqxtn2 v15.16b, v14.8h
+
+    cmp w23, #4
+    blt Write // if rows < 4
+    cmp w15, #4
+    blt Write // if cols < 4
+
+    st1 {v15.s}[0], [x2], x24
+    st1 {v15.s}[1], [x2], x24
+    st1 {v15.s}[2], [x2], x24
+    st1 {v15.s}[3], [x2], x24
+    b Endwrite
+
+Write:
+    cmp w15, #4
+    beq WriteCol4
+    cmp w15, #3
+    beq WriteCol3
+    cmp w15, #2
+    beq WriteCol2
+    cmp w15, #1
+    beq WriteCol1
+
+WriteCol4:
+    st1 {v15.s}[0], [x2], x24
+    cmp w23, #1
+    beq Endwrite
+    st1 {v15.s}[1], [x2], x24
+    cmp w23, #2
+    beq Endwrite
+    st1 {v15.s}[2], [x2], x24
+    cmp w23, #3
+    beq Endwrite
+    st1 {v15.s}[3], [x2], x24
+    b Endwrite
+
+WriteCol3:
+    mov x26, x2
+    st1 {v15.b}[0], [x26], #1
+    st1 {v15.b}[1], [x26], #1
+    st1 {v15.b}[2], [x26], #1
+    add x2, x2, x24
+    cmp w23, #1
+    beq Endwrite
+    mov x26, x2
+    st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1 + st1 {v15.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + st1 {v15.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + st1 {v15.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol1: + st1 {v15.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.b}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.b}[8], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.b}[12], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #4 // a row4 counter - 4 + sub w23, w23, #4 // a row counter - 4 + b L2 + +End2: + sub w4, w4, #4 // b col4 counter - 4 + sub w15, w15, #4 // b col counter - 4 + add x1, x1, x21 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + add x25, x25, #4 // output + stride(4 * sizeof(int8)) + mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #16 + add x11, x11, #16 + add x13, x13, #16 +PerTEnd2: + b L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8Opt.S new file mode 100644 index 00000000..5fbb9825 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulInt8Opt.S @@ -0,0 +1,356 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: row4 +// x4: col4 +// x5: deep16 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// x14: stride +// x15: filter_peroc +// x28: filter_zp + +asm_function MatmulInt8Opt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + stp x29, x30, [sp, #208] + + ldr w8, [sp, #224] + ldr w9, [sp, #232] + ldr w10, [sp, #240] + ldr x11, [sp, #248] + ldr x12, [sp, #256] + ldr x13, [sp, #264] + ldr x14, [sp, #272] + ldr x15, [sp, #280] + + mov x23, #4 + mul x23, x23, x5 // lhs step + mov x24, #4 + mul x24, x24, x14 // dst step +LoopRow: + mov x16, x1 // reload rhs ptr + mov x17, x4 // reload rhs col + mov x29, x7 // reload bias ptr + mov x27, x2 // reload dst ptr + ldr x28, [sp, #288] // reload filter_zp + + LoopCol: + mov x25, x6 // reload a_sums ptr + mov x19, x27 // reload dst ptr + mov x20, x0 // reload lhs ptr + mov x21, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + LoopDepth: + ld1 {v0.16b, v1.16b}, [x20], #32 + ld1 {v4.16b, v5.16b}, [x16], #32 + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + ld1 {v6.16b, v7.16b}, [x16], #32 + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + ld1 {v2.16b, v3.16b}, [x20], #32 + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs x21, x21, #16 // depth - 16 + bgt LoopDepth + + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp 
v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + Bias: + cbz x7, NoBias + ld1 {v15.4s}, [x29], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + NoBias: + ld1r {v20.4s}, [x25], #4 + ld1r {v21.4s}, [x25], #4 + ld1r {v22.4s}, [x25], #4 + ld1r {v23.4s}, [x25], #4 + cbz x15, ApplySum + + ld1 {v14.4s}, [x28], #16 + mul v20.4s, v20.4s, v14.4s + mul v21.4s, v21.4s, v14.4s + mul v22.4s, v22.4s, v14.4s + mul v23.4s, v23.4s, v14.4s + + ApplySum: + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + cbnz x15, PerCLoad + + ld1r {v13.4s}, [x12] + ld1r {v12.4s}, [x11] + ld1r {v11.4s}, [x13] + b Quantize + + PerCLoad: + ld1 {v13.4s}, [x12], #16 + ld1 {v12.4s}, [x11], #16 + ld1 {v11.4s}, [x13], #16 + + Quantize: + sqshl v16.4s, v16.4s, v13.4s + sqshl v17.4s, v17.4s, v13.4s + sqshl v18.4s, v18.4s, v13.4s + sqshl v19.4s, v19.4s, v13.4s + + sqrdmulh v16.4s, v16.4s, v12.4s + sqrdmulh v17.4s, v17.4s, v12.4s + sqrdmulh v18.4s, v18.4s, v12.4s + sqrdmulh v19.4s, v19.4s, v12.4s + + and v20.16b, v11.16b, v16.16b + sshr v20.4s, v20.4s, #31 + sqadd v16.4s, v16.4s, v20.4s + srshl v16.4s, v16.4s, v11.4s + and v21.16b, v11.16b, v17.16b + sshr v21.4s, v21.4s, #31 + sqadd v17.4s, v17.4s, v21.4s + srshl v17.4s, v17.4s, v11.4s + and v22.16b, v11.16b, v18.16b + sshr v22.4s, v22.4s, #31 + sqadd v18.4s, v18.4s, v22.4s + srshl v18.4s, v18.4s, v11.4s + and v23.16b, v11.16b, v19.16b + sshr v23.4s, v23.4s, #31 + sqadd v19.4s, v19.4s, v23.4s + srshl v19.4s, v19.4s, v11.4s + + dup v10.4s, w10 + add v16.4s, v16.4s, v10.4s + add v17.4s, v17.4s, v10.4s + add v18.4s, v18.4s, v10.4s + add v19.4s, v19.4s, v10.4s + + dup v9.4s, w8 + smax v16.4s, v16.4s, v9.4s + smax v17.4s, v17.4s, v9.4s + smax v18.4s, v18.4s, v9.4s + smax v19.4s, v19.4s, v9.4s + + dup v8.4s, w9 + smin v16.4s, v16.4s, v8.4s + smin v17.4s, v17.4s, v8.4s + smin v18.4s, v18.4s, v8.4s + smin v19.4s, v19.4s, v8.4s + + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v17.4s + sqxtn v14.4h, v18.4s + sqxtn2 v14.8h, v19.4s + + sqxtn v15.8b, v13.8h + sqxtn2 v15.16b, v14.8h + + cmp x17, #1 + beq Write1 + cmp x17, #2 + beq Write2 + cmp x17, #3 + beq Write3 + b Write4 + + Write1: + add x27, x27, #1 + st1 {v15.b}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.b}[4], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.b}[8], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.b}[12], [x19], x14 + b WriteEnd + Write2: + add x27, x27, #2 + st1 {v15.h}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.h}[2], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.h}[4], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.h}[6], [x19], x14 + b WriteEnd + Write3: + add x27, x27, #3 + add x22, x19, #2 + st1 {v15.h}[0], [x19], x14 + st1 {v15.b}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.h}[2], [x19], x14 + st1 {v15.b}[6], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.h}[4], [x19], x14 + st1 {v15.b}[10], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.h}[6], [x19], x14 + st1 {v15.b}[14], [x22], x14 + b WriteEnd + Write4: + add x27, x27, #4 + st1 {v15.s}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.s}[1], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.s}[2], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.s}[3], [x19], x14 + 
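+    // Tile writeback: x17 is the number of output columns left, so
+    // Write1..Write3 store only the valid bytes (1-3 per row) and Write4
+    // stores the full 4-byte group; x14 is the dst row stride in bytes
+    // and x27 advances to the start column of the next tile.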
+ WriteEnd: + subs x17, x17, #4 + bgt LoopCol + +LoopColEnd: + subs x3, x3, #4 + ble LoopRowEnd + ldr x11, [sp, #248] + ldr x12, [sp, #256] + ldr x13, [sp, #264] + add x6, x6, #16 + add x0, x0, x23 + add x2, x2, x24 + b LoopRow + +LoopRowEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S new file mode 100644 index 00000000..01f29170 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulR4Int8.S @@ -0,0 +1,193 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, +// const int *input_sum, const int *bias) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias + +asm_function MatMulR4Int8Neon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + mov w15, #0 // b col index + mov w16, #0 // a row index + mov w17, #4 // sizeof(int8)*4 + mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + +L1: + cmp w15, w4 + beq End1 + + mov w16, #0 // reset a row index + mov x17, x0 // reload a ptr + mov x13, x6 // reload a_sums ptr +L2: + cmp w16, w3 + beq End2 + + mov x19, x1 // reload b ptr + mov x10, x7 // reload bias ptr + mov w11, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w11, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x19], #16 + ld1 {v5.16b}, [x19], #16 + ld1 {v6.16b}, [x19], #16 + ld1 {v7.16b}, [x19], #16 + + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + 
sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs w11, w11, #16 // depth + 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x10], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + // Subtract (Asums*Zb) + ld1 {v14.4s}, [x13], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + add w16, w16, #4 // a row index + 4 + b L2 + +End2: + add w15, w15, #4 // b col index + 4 + add x1, x1, x12 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + b L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulWinogradFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulWinogradFp32.S new file mode 100644 index 00000000..ef5b39a4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/MatmulWinogradFp32.S @@ -0,0 +1,183 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// MatrixMultiplyWinograd(float *matrix_a, float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel)
+    // x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel, x7: c4_channel
+asm_function MatrixMultiplyWinograd
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved,
+    // whereas our coding style does not permit such an amount of parameters
+    sub sp, sp, #48
+    st1 {v8.4s}, [sp]
+    stp x19, x20, [sp, #16]
+    stp x21, x22, [sp, #32]
+    mov x8, #4
+    mul x10, x5, x8
+    mov x17, x3 // m
+    mul x13, x6, x8 // in_channel * 4
+    mul x21, x13, x4 // in_channel * k
+
+    LoopM:
+        mov x15, x5 // n
+        mov x14, x1 // mat_b
+    LoopN:
+        mov x16, x0 // mat_a_m
+        sub x22, x5, x15 // ni
+        sub x19, x17, x3 // mi
+        mul x22, x22, x17 // ni * m
+        mov x11, x6 // in_channel
+        add x22, x22, x19 // (ni * m) + mi
+        mul x22, x22, x7 // x22 * c4_channel
+        add x20, x2, x22 // dst + offset
+        cmp x11, #16
+        bge LoopC16
+        cmp x11, #8
+        bge LoopC8
+        cmp x11, #4
+        bge LoopC4
+        cmp x11, #1
+        bge LoopC
+        b EndLoopC
+        LoopC16:
+            mov x12, x14
+            mov x9, x4 // new_k
+            dup v5.4s, wzr
+            dup v6.4s, wzr
+            dup v7.4s, wzr
+            dup v8.4s, wzr
+            LoopK16:
+                ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x16], x13
+                ldr s4, [x12]
+                add x12, x12, x10
+                fmla v5.4s, v0.4s, v4.s[0]
+                fmla v6.4s, v1.4s, v4.s[0]
+                fmla v7.4s, v2.4s, v4.s[0]
+                fmla v8.4s, v3.4s, v4.s[0]
+                subs x9, x9, #1
+                bne LoopK16
+            Write16:
+                st1 {v5.4s}, [x20], #16
+                st1 {v6.4s}, [x20], #16
+                st1 {v7.4s}, [x20], #16
+                st1 {v8.4s}, [x20], #16
+
+            sub x16, x16, x21 // back x13 * k
+            add x16, x16, #64 // add 64B
+            subs x11, x11, #16
+            beq EndLoopC
+            cmp x11, #16
+            bge LoopC16
+            cmp x11, #8
+            bge LoopC8
+            cmp x11, #4
+            bge LoopC4
+            cmp x11, #1
+            bge LoopC
+
+        LoopC8:
+            dup v5.4s, wzr
+            dup v6.4s, wzr
+            mov x9, x4 // new_k
+            mov x12, x14
+            LoopK8:
+                ld1 {v0.4s, v1.4s}, [x16], x13
+                ldr s4, [x12]
+                add x12, x12, x10
+                fmla v5.4s, v0.4s, v4.s[0]
+                fmla v6.4s, v1.4s, v4.s[0]
+                subs x9, x9, #1
+                bne LoopK8
+            Write8:
+                st1 {v5.4s}, [x20], #16
+                st1 {v6.4s}, [x20], #16
+
+            sub x16, x16, x21 // back x13 * k
+            add x16, x16, #32 // add 32B
+            subs x11, x11, #8
+            beq EndLoopC
+            cmp x11, #8
+            bge LoopC8
+            cmp x11, #4
+            bge LoopC4
+            cmp x11, #1
+            bge LoopC
+
+        LoopC4:
+            dup v5.4s, wzr
+            mov x9, x4 // new_k
+            mov x12, x14
+            LoopK4:
+                ld1 {v0.4s}, [x16], x13
+                ldr s4, [x12]
+                add x12, x12, x10
+                fmla v5.4s, v0.4s, v4.s[0]
+                subs x9, x9, #1
+                bne LoopK4
+            Write4:
+                st1 {v5.4s}, [x20], #16
+
+            sub x16, x16, x21 // ptr back x13 * k
+            add x16, x16, #16 // add 16B
+            subs x11, x11, #4
+            beq EndLoopC
+            cmp x11, #4
+            bge LoopC4
+            cmp x11, #1
+            bge LoopC
+
+        LoopC:
+            dup v5.4s, wzr
+            mov x9, x4 // new_k
+            mov x12, x14
+            LoopK:
+                ldr s0, [x16]
+                add x16, x16, x13
+                ldr s4, [x12]
+                add x12, x12, x10
+                fmla v5.4s, v0.4s, v4.s[0]
+                subs x9, x9, #1
+                bne LoopK
+            Write1:
+                str s5, [x20], #4
+
+            sub x16, x16, x21 // ptr back x13 * k
+            add x16, x16, #4 // ptr add 4B
+            subs x11, x11, #1
+            beq EndLoopC
+            b LoopC
+
+        EndLoopC:
+            add x14, x14, #4
+            subs x15, x15, #1
+            beq EndLoopN
+            b LoopN
+    EndLoopN:
+        subs x3, x3, #1
+        beq EndLoopM
+        add x0, x0, x21
+        b LoopM
+    EndLoopM:
+        ld1 {v8.4s}, [sp], #16
+        ldp x19, x20, [sp], #16
+        ldp x21, x22, [sp], #16
+        ret
+#endif
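In C terms, the loop nest above computes the following (a sketch with illustrative names; it assumes matrix_a is packed [m][k][in_channel], matrix_b is [k][n], and c4_channel is the per-(n, m) output stride of matrix_c, expressed here in floats rather than bytes):

void MatrixMultiplyWinogradRef(const float *matrix_a, const float *matrix_b, float *matrix_c,
                               int m, int k, int n, int in_channel, int c4_channel) {
  for (int mi = 0; mi < m; ++mi) {      // LoopM (x3 counts down)
    for (int ni = 0; ni < n; ++ni) {    // LoopN (x15 counts down)
      float *dst = matrix_c + (ni * m + mi) * c4_channel;  // 'add x20, x2, x22'
      for (int ic = 0; ic < in_channel; ++ic) {            // LoopC16/C8/C4/C blocks
        float acc = 0.0f;
        for (int ki = 0; ki < k; ++ki) {  // LoopK*, a strided by in_channel * 4 bytes
          acc += matrix_a[(mi * k + ki) * in_channel + ic] * matrix_b[ki * n + ni];
        }
        dst[ic] = acc;
      }
    }
  }
}

diff --git 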
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC4.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC4.S new file mode 100644 index 00000000..e88eab25 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC4.S @@ -0,0 +1,316 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// x0 dst x1 srx x2 bias +// w3 oc4div w4 oc4mod w5 plane_size +// x6 plane_stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x12 oc_stride +// x14 x15 write loop tmp buf +// v26 relu6 #6; v27 relu #0 +// w10 oc4 loop control +// w13 hw loop control + + +asm_function WinogradPostFuncBiasReluC4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + mov x10, #4 + add x12, x3, x4 + mul x12, x12, x10 + + mov w10, #0 + +Loop_C4: + cmp w10, w3 + beq Loop_C1 + mov x15, #4 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #4 + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + +Loop_8x4: + cmp w13, #8 + blt Loop_4x4 + sub w13, w13, #8 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v16.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v16.4s + + cmp x7, #3 + beq Relu6_8x4 + cmp x7, #1 + beq Relu_8x4 + b Write_8x4 +Relu6_8x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s +Relu_8x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_8x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + st1 {v4.4s}, [x15], x12 + st1 {v5.4s}, [x15], x12 + st1 {v6.4s}, [x15], x12 + st1 {v7.4s}, [x15], x12 + b Loop_8x4 + +Loop_4x4: + cmp w13, #4 + blt Loop_1x4 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + cmp x7, #3 + beq Relu6_4x4 + cmp x7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s +Relu_4x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s +Write_4x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + 
st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + +Loop_1x4: + cmp x7, #3 + beq Relu6_1x4 + cmp x7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu6_1x4 +Relu_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu_1x4 +Write_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + st1 {v0.4s}, [x15], x12 + b Write_1x4 + +HW_Add: + add x1, x1, x6 + b Loop_C4 + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + mov x15, #4 + mul x14, x10, x15 + add x0, x0, x14 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Write + +Loop_C1_3: + add x15, x0, #8 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Write + +End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC8.S new file mode 100644 index 00000000..99213447 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncBiasReluC8.S
@@ -0,0 +1,553 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+//void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
+//                        size_t plane_size, size_t stride, int relu_type);
+// x0 dst x1 src x2 bias
+// x3 oc8div x4 oc8mod x5 plane_size
+// x6 stride x7 relu_type
+
+// v0 ~ v15 value
+// v16 v17 bias data
+// x14 x15 write loop tmp buf
+// v26 relu6 #6; v27 relu #0
+// w10 oc8 loop control
+// w13 hw loop control
+
+asm_function PostFuncBiasReluC8
+    sub sp, sp, #128
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp]
+    add x9, sp, #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9]
+
+    movi v26.4s, #6
+    scvtf v26.4s, v26.4s
+    dup v27.4s, wzr
+    mov w10, #0
+
+Loop_C8:
+    cmp w10, w3
+    beq Loop_C1
+    mov x15, #4
+    mul x14, x10, x15
+    add x15, x0, x14
+    add w10, w10, #8
+    mov w13, w5
+    ld1 {v16.4s, v17.4s}, [x2], #32
+
+Loop_8x8:
+    cmp w13, #8
+    blt Loop_4x8
+    sub w13, w13, #8
+    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
+    ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
+
+    fadd v0.4s, v0.4s, v16.4s
+    fadd v1.4s, v1.4s, v17.4s
+    fadd v2.4s, v2.4s, v16.4s
+    fadd v3.4s, v3.4s, v17.4s
+    fadd v4.4s, v4.4s, v16.4s
+    fadd v5.4s, v5.4s, v17.4s
+    fadd v6.4s, v6.4s, v16.4s
+    fadd v7.4s, v7.4s, v17.4s
+    fadd v8.4s, v8.4s, v16.4s
+    fadd v9.4s, v9.4s, v17.4s
+    fadd v10.4s, v10.4s, v16.4s
+    fadd v11.4s, v11.4s, v17.4s
+    fadd v12.4s, v12.4s, v16.4s
+    fadd v13.4s, v13.4s, v17.4s
+    fadd v14.4s, v14.4s, v16.4s
+    fadd v15.4s, v15.4s, v17.4s
+
+    cmp x7, #3
+    beq Relu6_8x8
+    cmp x7, #1
+    beq Relu_8x8
+    b Write_8x8
+Relu6_8x8:
+    fmin v0.4s, v0.4s, v26.4s
+    fmin v1.4s, v1.4s, v26.4s
+    fmin v2.4s, v2.4s, v26.4s
+    fmin v3.4s, v3.4s, v26.4s
+    fmin v4.4s, v4.4s, v26.4s
+    fmin v5.4s, v5.4s, v26.4s
+    fmin v6.4s, v6.4s, v26.4s
+    fmin v7.4s, v7.4s, v26.4s
+    fmin v8.4s, v8.4s, v26.4s
+    fmin v9.4s, v9.4s, v26.4s
+    fmin v10.4s, v10.4s, v26.4s
+    fmin v11.4s, v11.4s, v26.4s
+    fmin v12.4s, v12.4s, v26.4s
+    fmin v13.4s, v13.4s, v26.4s
+    fmin v14.4s, v14.4s, v26.4s
+    fmin v15.4s, v15.4s, v26.4s
+Relu_8x8:
+    fmax v0.4s, v0.4s, v27.4s
+    fmax v1.4s, v1.4s, v27.4s
+    fmax v2.4s, v2.4s, v27.4s
+    fmax v3.4s, v3.4s, v27.4s
+    fmax v4.4s, v4.4s, v27.4s
+    fmax v5.4s, v5.4s, v27.4s
+    fmax v6.4s, v6.4s, v27.4s
+    fmax v7.4s, v7.4s, v27.4s
+    fmax v8.4s, v8.4s, v27.4s
+    fmax v9.4s, v9.4s, v27.4s
+    fmax v10.4s, v10.4s, v27.4s
+    fmax v11.4s, v11.4s, v27.4s
+    fmax v12.4s, v12.4s, v27.4s
+    fmax v13.4s, v13.4s, v27.4s
+    fmax v14.4s, v14.4s, v27.4s
+    fmax v15.4s, v15.4s, v27.4s
+Write_8x8:
+    st1 {v0.4s, v1.4s}, [x15], x6
+    st1 {v2.4s, v3.4s}, [x15], x6
+    st1 {v4.4s, v5.4s}, [x15], x6
+    st1 {v6.4s, v7.4s}, [x15], x6
+    st1 {v8.4s, v9.4s}, [x15], x6
+    st1 {v10.4s, v11.4s}, [x15], x6
+    st1 {v12.4s, 
v13.4s}, [x15], x6 + st1 {v14.4s, v15.4s}, [x15], x6 + b Loop_8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v17.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v17.4s + + cmp x7, #3 + beq Relu6_4x8 + cmp x7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s +Relu_4x8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_4x8: + st1 {v0.4s, v1.4s}, [x15], x6 + st1 {v2.4s, v3.4s}, [x15], x6 + st1 {v4.4s, v5.4s}, [x15], x6 + st1 {v6.4s, v7.4s}, [x15], x6 + +Loop_1x8: + cmp x7, #3 + beq Relu6_1x8 + cmp x7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s, v17.4s}, [x2], #32 + mov x15, #4 + mul x14, x10, x15 + add x0, x0, x14 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, 
x0, x6
+    b Loop_C1_2_Relu
+Loop_C1_2_Write:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    dup s1, v0.s[1]
+    stp s0, s1, [x0]
+    add x0, x0, x6
+    b Loop_C1_2_Write
+
+
+Loop_C1_3:
+    add x15, x0, #8
+    cmp x7, #3
+    beq Loop_C1_3_Relu6
+    cmp x7, #1
+    beq Loop_C1_3_Relu
+    b Loop_C1_3_Write
+Loop_C1_3_Relu6:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fmin v0.4s, v0.4s, v26.4s
+    fmax v0.4s, v0.4s, v27.4s
+    dup s1, v0.s[1]
+    stp s0, s1, [x0]
+    add x0, x0, x6
+    st1 {v0.s}[2], [x15], x6
+    b Loop_C1_3_Relu6
+Loop_C1_3_Relu:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fmax v0.4s, v0.4s, v27.4s
+    dup s1, v0.s[1]
+    stp s0, s1, [x0]
+    add x0, x0, x6
+    st1 {v0.s}[2], [x15], x6
+    b Loop_C1_3_Relu
+Loop_C1_3_Write:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    dup s1, v0.s[1]
+    stp s0, s1, [x0]
+    add x0, x0, x6
+    st1 {v0.s}[2], [x15], x6
+    b Loop_C1_3_Write
+
+Loop_C1_4:
+    cmp x7, #3
+    beq Loop_C1_4_Relu6
+    cmp x7, #1
+    beq Loop_C1_4_Relu
+    b Loop_C1_4_Write
+Loop_C1_4_Relu6:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fmin v0.4s, v0.4s, v26.4s
+    fmax v0.4s, v0.4s, v27.4s
+    st1 {v0.4s}, [x0], x6
+    b Loop_C1_4_Relu6
+Loop_C1_4_Relu:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fmax v0.4s, v0.4s, v27.4s
+    st1 {v0.4s}, [x0], x6
+    b Loop_C1_4_Relu
+Loop_C1_4_Write:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    st1 {v0.4s}, [x0], x6
+    b Loop_C1_4_Write
+
+Loop_C1_5:
+    add x15, x0, #16
+    cmp x7, #3
+    beq Loop_C1_5_Relu6
+    cmp x7, #1
+    beq Loop_C1_5_Relu
+    b Loop_C1_5_Write
+Loop_C1_5_Relu6:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fadd v1.4s, v1.4s, v17.4s
+    fmin v0.4s, v0.4s, v26.4s
+    fmin v1.4s, v1.4s, v26.4s
+    fmax v0.4s, v0.4s, v27.4s
+    fmax v1.4s, v1.4s, v27.4s
+    st1 {v0.4s}, [x0], x6
+    str s1, [x15]
+    add x15, x15, x6
+    b Loop_C1_5_Relu6
+Loop_C1_5_Relu:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fadd v1.4s, v1.4s, v17.4s
+    fmax v0.4s, v0.4s, v27.4s
+    fmax v1.4s, v1.4s, v27.4s
+    st1 {v0.4s}, [x0], x6
+    str s1, [x15]
+    add x15, x15, x6
+    b Loop_C1_5_Relu
+Loop_C1_5_Write:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fadd v1.4s, v1.4s, v17.4s
+    st1 {v0.4s}, [x0], x6
+    str s1, [x15]
+    add x15, x15, x6
+    b Loop_C1_5_Write
+
+Loop_C1_6:
+    add x15, x0, #16
+    cmp x7, #3
+    beq Loop_C1_6_Relu6
+    cmp x7, #1
+    beq Loop_C1_6_Relu
+    b Loop_C1_6_Write
+Loop_C1_6_Relu6:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fadd v1.4s, v1.4s, v17.4s
+    fmin v0.4s, v0.4s, v26.4s
+    fmin v1.4s, v1.4s, v26.4s
+    fmax v0.4s, v0.4s, v27.4s
+    fmax v1.4s, v1.4s, v27.4s
+    st1 {v0.4s}, [x0], x6
+    dup s0, v1.s[1]
+    stp s1, s0, [x15]
+    add x15, x15, x6
+    b Loop_C1_6_Relu6
+Loop_C1_6_Relu:
+    cmp w13, #0
+    beq End
+    sub w13, w13, #1
+    ld1 {v0.4s, v1.4s}, [x1], #32
+    fadd v0.4s, v0.4s, v16.4s
+    fadd v1.4s, v1.4s, v17.4s
+    fmax v0.4s, v0.4s, v27.4s
+    fmax v1.4s, v1.4s, v27.4s
+    st1 {v0.4s}, [x0], x6
+    dup s0, v1.s[1]
+    stp s1, s0, [x15]
+    add x15, x15, x6
+    b Loop_C1_6_Relu
+Loop_C1_6_Write:
+ 
cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x15, x0, #16 + add x14, x0, #24 + cmp x7, #3 + beq Loop_C1_7_Relu6 + cmp x7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Write + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncInt8C4Neon64.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncInt8C4Neon64.S new file mode 100644 index 00000000..71c44685 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PostFuncInt8C4Neon64.S @@ -0,0 +1,259 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, +// size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, +// int32_t zp, int32_t mini, int32_t maxi); +// x0 in +// x1 bias +// x2 out +// x3 oc4div +// x4 oc4res +// x5 plane +// x6 stride +// x7 multiplier +// x8 left_shift +// x9 right_shift +// x10 zp +// x11 mini +// x12 maxi + +// v0 ~ v15 value +// x24 x25 write loop tmp buf + + +// v16 bias data + +// v26 multiplier +// v27 left_shift +// v28 right_shift +// v29 zp +// v30 min +// v31 max + +// w15 oc4 loop control +// w16 hw loop control + +asm_function PostFuncInt8C4Neon64 + sub sp, sp, #16 + stp x24, x25, [sp] + + ldr w8, [sp, #16] + ldr w9, [sp, #24] + ldr w10, [sp, #32] + ldr w11, [sp, #40] + ldr w12, [sp, #48] + ldr w13, [sp, #56] + + dup v26.4s, w7 + dup v27.4s, w8 + dup v28.4s, w9 + dup v29.4s, w10 + dup v30.4s, w11 + dup v31.4s, w12 + + mov x15, #0 + +Loop_C4: + cmp x15, x3 + beq Loop_C1 + mov x25, #4 + mul x24, x15, x25 + add x25, x2, x24 + add w15, w15, #4 + mov w16, w5 + ld1 {v16.4s}, [x1], #16 + +Loop_4x4: + cmp x16, #4 + blt Loop_1x4 + sub x16, x16, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 + + add v0.4s, v0.4s, v16.4s + add v1.4s, v1.4s, v16.4s + add v2.4s, v2.4s, v16.4s + add v3.4s, v3.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqshl v1.4s, v1.4s, v27.4s + sqshl v2.4s, v2.4s, v27.4s + sqshl v3.4s, v3.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + sqrdmulh v1.4s, v1.4s, v26.4s + sqrdmulh v2.4s, v2.4s, v26.4s + sqrdmulh v3.4s, v3.4s, v26.4s + and v4.16b, v28.16b, v0.16b + and v5.16b, v28.16b, v1.16b + and v6.16b, v28.16b, v2.16b + and v7.16b, v28.16b, v3.16b + sshr v4.4s, v4.4s, #31 + sshr v5.4s, v5.4s, #31 + sshr v6.4s, v6.4s, #31 + sshr v7.4s, v7.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + sqadd v1.4s, v1.4s, v5.4s + sqadd v2.4s, v2.4s, v6.4s + sqadd v3.4s, v3.4s, v7.4s + srshl v0.4s, v0.4s, v28.4s + srshl v1.4s, v1.4s, v28.4s + srshl v2.4s, v2.4s, v28.4s + srshl v3.4s, v3.4s, v28.4s + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + add v2.4s, v2.4s, v29.4s + add v3.4s, v3.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smax v2.4s, v2.4s, v30.4s + smax v3.4s, v3.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v31.4s + smin v3.4s, v3.4s, v31.4s + sqxtn v4.4h, v0.4s + sqxtn v5.4h, v1.4s + sqxtn v6.4h, v2.4s + sqxtn v7.4h, v3.4s + sqxtn v0.8b, v4.8h + sqxtn v1.8b, v5.8h + sqxtn v2.8b, v6.8h + sqxtn v3.8b, v7.8h + + st1 {v0.s}[0], [x25], x6 + st1 {v1.s}[0], [x25], x6 + st1 {v2.s}[0], [x25], x6 + st1 {v3.s}[0], [x25], x6 + b Loop_4x4 + + +Loop_1x4: + cmp x16, #0 + beq Loop_C4 + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.s}[0], [x25], x6 + b Loop_1x4 + +Loop_C1: + cmp x4, #0 + beq End + mov x16, x5 + ld1 {v16.4s}, [x1], #16 + mov x25, #4 + mul x24, x15, x25 + add x25, x2, x24 + add x24, x25, #2 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + 
sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.b}[0], [x25], x6 + b Loop_C1_1 + + +Loop_C1_2: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.h}[0], [x25], x6 + b Loop_C1_2 + + +Loop_C1_3: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.h}[0], [x25], x6 + st1 {v0.b}[2], [x24], x6 + b Loop_C1_3 + + +End: + ldp x24, x25, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Peroc.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Peroc.S new file mode 100644 index 00000000..53b6ec5a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Peroc.S @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div4, +// size_t oc_res4, size_t stride); + +// x0 src +// x1 sum +// x2 zp +// w3 hw4 +// w4 ic16 +// w5 oc_div4 +// w6 oc_res4 +// w7 stride + +asm_function PreSum4x16Int8Peroc + mov w8, #0 + +RowLoop: + cmp w8, w3 + beq End + add w8, w8, #4 + dup v16.4s, wzr + mov w9, #0 + mov x16, x2 + +Sum: + cmp w9, w4 + beq Mul + add w9, w9, #16 + + ld1 {v0.16b}, [x0], #16 + ld1 {v1.16b}, [x0], #16 + ld1 {v2.16b}, [x0], #16 + ld1 {v3.16b}, [x0], #16 + + saddlp v4.8h, v0.16b + saddlp v5.8h, v1.16b + saddlp v6.8h, v2.16b + saddlp v7.8h, v3.16b + saddlp v0.4S, v4.8h + saddlp v1.4S, v5.8h + saddlp v2.4S, v6.8h + saddlp v3.4S, v7.8h + addv s4, v0.4S + addv s5, v1.4S + addv s6, v2.4S + addv s7, v3.4S + mov v0.s[0], v4.s[0] + mov v0.s[1], v5.s[0] + mov v0.s[2], v6.s[0] + mov v0.s[3], v7.s[0] + + add v16.4s, v16.4s, v0.4s + b Sum + +Mul: + mov x12, x1 + add x1, x1, #64 + mov w9, #0 + + dup v1.4s, v16.s[0] + dup v2.4s, v16.s[1] + dup v3.4s, v16.s[2] + dup v4.4s, v16.s[3] + +WriteOc4: + cmp w9, w5 + beq OcRes4 + add w9, w9, #4 + ld1 {v5.4s}, [x16], #16 + + mul v16.4s, v5.4s, v1.4s + mul v17.4s, v5.4s, v2.4s + mul v18.4s, v5.4s, v3.4s + mul v19.4s, v5.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + add x12, x12, x7 + b WriteOc4 + +OcRes4: + cmp w6, #0 + beq RowLoop + dup v15.4s, wzr + cmp w6, #1 + beq OcRes4_1 + cmp w6, #2 + beq OcRes4_2 + cmp w6, #3 + beq OcRes4_3 + +OcRes4_1: + ld1 {v15.s}[0], [x16] + b OcRes4End + +OcRes4_2: + ld1 {v15.d}[0], [x16] + b OcRes4End + +OcRes4_3: + ld1 {v15.d}[0], [x16] + add x16, x16, #8 + ld1 {v15.s}[2], [x16] + b OcRes4End + +OcRes4End: + mul v16.4s, v15.4s, v1.4s + mul v17.4s, v15.4s, v2.4s + mul v18.4s, v15.4s, v3.4s + mul v19.4s, v15.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + b RowLoop + +End: + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Pert.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Pert.S new file mode 100644 index 00000000..1e7f0709 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/PreSum4x16Int8Pert.S @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void PreSum4x16Int8Pert(const int8_t *src, int32_t *dst, size_t row4, size_t col16, int32_t filter_zp);
+
+// x0 src
+// x1 dst
+// w2 row4
+// w3 col16
+// w4 filter_zp
+
+asm_function PreSum4x16Int8Pert
+    dup v17.4s, w4
+    mov w5, #0
+
+RowLoop:
+    cmp w5, w2
+    beq End
+    add w5, w5, #4
+    dup v16.4s, wzr
+    mov w6, #0
+
+CalLoop:
+    cmp w6, w3
+    beq Write
+    add w6, w6, #16
+
+    ld1 {v0.16b}, [x0], #16
+    ld1 {v1.16b}, [x0], #16
+    ld1 {v2.16b}, [x0], #16
+    ld1 {v3.16b}, [x0], #16
+
+    saddlp v4.8h, v0.16b
+    saddlp v5.8h, v1.16b
+    saddlp v6.8h, v2.16b
+    saddlp v7.8h, v3.16b
+
+    saddlp v0.4S, v4.8h
+    saddlp v1.4S, v5.8h
+    saddlp v2.4S, v6.8h
+    saddlp v3.4S, v7.8h
+
+    addv s4, v0.4S
+    addv s5, v1.4S
+    addv s6, v2.4S
+    addv s7, v3.4S
+
+    mov v0.s[0], v4.s[0]
+    mov v0.s[1], v5.s[0]
+    mov v0.s[2], v6.s[0]
+    mov v0.s[3], v7.s[0]
+
+    add v16.4s, v16.4s, v0.4s
+    b CalLoop
+
+Write:
+    mul v16.4s, v16.4s, v17.4s
+    st1 {v16.4s}, [x1], #16
+    beq RowLoop
+
+End:
+    ret
+#endif
\ No newline at end of file
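In C terms, the kernel above produces one int32 per input row: the sum of that row's int8 values scaled by the per-tensor filter zero point (a sketch; the 4x16 tile packing of src is inferred from the four consecutive 16-byte loads):

#include <stdint.h>
#include <stddef.h>

void PreSum4x16Int8PertRef(const int8_t *src, int32_t *dst, size_t row4, size_t col16, int32_t filter_zp) {
  for (size_t r = 0; r < row4; r += 4) {      // RowLoop: 4 rows per iteration
    int32_t sum[4] = {0, 0, 0, 0};
    for (size_t c = 0; c < col16; c += 16) {  // CalLoop: one 4x16 tile per iteration
      for (int i = 0; i < 4; ++i) {
        for (int j = 0; j < 16; ++j) {
          sum[i] += *src++;                   // the saddlp/addv reduction in the assembly
        }
      }
    }
    for (int i = 0; i < 4; ++i) {
      *dst++ = sum[i] * filter_zp;            // 'mul v16.4s, v16.4s, v17.4s'
    }
  }
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S
new file mode 100644
index 00000000..e2317a70
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/SPMM8x8Fp32.S
@@ -0,0 +1,294 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 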
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void SPMM8x8Fp32(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, +// const float *bias, ActType act_type, size_t out_stride); +// x0: a +// x1: b +// x2: nnz +// x3: dmap +// x4: c +// x5: bias +// w6: act_type +// x7: out_stride + +// wdata tmp w8 +// loop_oc_count w9 +// loop_nnz_count w10 +// dmap tmp w11 +// a_ptr +// 8 x 1 fp32 A v0-v1 +// fp32 B-value v2 +// uint32 B-NNZ x9 +// uint32 B-INDEX x10 +// 4 MIN v3 +// 4 MAX v4 +// 2 vacc v5-v6 +// 8 x 8 fp32 C v16-v31 + +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] v20[1] v22[1] v24[1] v26[1] v28[1] v30[1] +// v16[2] v18[2] v20[2] v22[2] v24[2] v26[2] v28[2] v30[2] +// v16[3] v18[3] v20[3] v22[3] v24[3] v26[3] v28[3] v30[3] +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] v21[1] v23[1] v25[1] v27[1] v29[1] v31[1] +// v17[2] v19[2] v21[2] v23[2] v25[2] v27[2] v29[2] v31[2] +// v17[3] v19[3] v21[3] v23[3] v25[3] v27[3] v29[3] v31[3] + +asm_function SPMM8x8Fp32 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + // init output with bias + ldr w8, [x5], #4 + dup v16.4s, w8 + dup v17.4s, w8 + ldr w8, [x5], #4 + dup v18.4s, w8 + dup v19.4s, w8 + ldr w8, [x5], #4 + dup v20.4s, w8 + dup v21.4s, w8 + ldr w8, [x5], #4 + dup v22.4s, w8 + dup v23.4s, w8 + ldr w8, [x5], #4 + dup v24.4s, w8 + dup v25.4s, w8 + ldr w8, [x5], #4 + dup v26.4s, w8 + dup v27.4s, w8 + ldr w8, [x5], #4 + dup v28.4s, w8 + dup v29.4s, w8 + ldr w8, [x5] + dup v30.4s, w8 + dup v31.4s, w8 + + // OC 0 + ldr w10, [x2], #4 // load nnz + cmp w10, #0 + beq OC_1 +LOOP_NNZ0: + ldr x11, [x3], #8 // load dmap + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] // load inputs + ldr w8, [x1], #4 // load weight + dup v2.4s, w8 + // matmul + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ0 + +OC_1: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_2 +LOOP_NNZ1: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v18.4s, v0.4s, v2.4s + fmla v19.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ1 + +OC_2: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_3 +LOOP_NNZ2: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v20.4s, v0.4s, v2.4s + fmla v21.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ2 + +OC_3: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_4 +LOOP_NNZ3: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v22.4s, v0.4s, v2.4s + fmla v23.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ3 + +OC_4: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_5 +LOOP_NNZ4: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v24.4s, v0.4s, v2.4s + fmla v25.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ4 + +OC_5: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_6 +LOOP_NNZ5: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v26.4s, v0.4s, v2.4s + fmla v27.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ5 + +OC_6: + ldr w10, 
[x2], #4 + cmp w10, #0 + beq OC_7 +LOOP_NNZ6: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v28.4s, v0.4s, v2.4s + fmla v29.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ6 + +OC_7: + ldr w10, [x2], #4 + cmp w10, #0 + beq REORDER_OUT +LOOP_NNZ7: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v30.4s, v0.4s, v2.4s + fmla v31.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ7 + + // reorder output +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] v20[1] v22[1] v24[1] v26[1] v28[1] v30[1] +// v16[2] v18[2] v20[2] v22[2] v24[2] v26[2] v28[2] v30[2] +// v16[3] v18[3] v20[3] v22[3] v24[3] v26[3] v28[3] v30[3] +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] v21[1] v23[1] v25[1] v27[1] v29[1] v31[1] +// v17[2] v19[2] v21[2] v23[2] v25[2] v27[2] v29[2] v31[2] +// v17[3] v19[3] v21[3] v23[3] v25[3] v27[3] v29[3] v31[3] + +// v0[0] v0[1] v0[2] v0[3] v1[0] v1[1] v1[2] v1[3] +// v2[0] v2[1] v2[2] v2[3] v3[0] v3[1] v3[2] v3[3] +// v4[0] v4[1] v4[2] v4[3] v5[0] v5[1] v5[2] v5[3] +// v6[0] v6[1] v6[2] v6[3] v7[0] v7[1] v7[2] v7[3] +// v8[0] v8[1] v8[2] v8[3] v9[0] v9[1] v9[2] v9[3] +// v10[0] v10[1] v10[2] v10[3] v11[0] v11[1] v11[2] v11[3] +// v12[0] v12[1] v12[2] v12[3] v13[0] v13[1] v13[2] v13[3] +// v14[0] v14[1] v14[2] v14[3] v15[0] v15[1] v15[2] v15[3] + +REORDER_OUT: + zip1 v1.4s, v16.4s, v18.4s + zip2 v3.4s, v16.4s, v18.4s + zip1 v9.4s, v17.4s, v19.4s + zip2 v11.4s, v17.4s, v19.4s + zip1 v5.4s, v20.4s, v22.4s + zip2 v7.4s, v20.4s, v22.4s + zip1 v13.4s, v21.4s, v23.4s + zip2 v15.4s, v21.4s, v23.4s + trn1 v0.2d, v1.2d, v5.2d + trn2 v2.2d, v1.2d, v5.2d + trn1 v4.2d, v3.2d, v7.2d + trn2 v6.2d, v3.2d, v7.2d + trn1 v8.2d, v9.2d, v13.2d + trn2 v10.2d, v9.2d, v13.2d + trn1 v12.2d, v11.2d, v15.2d + trn2 v14.2d, v11.2d, v15.2d + + zip1 v16.4s, v24.4s, v26.4s + zip2 v17.4s, v24.4s, v26.4s + zip1 v20.4s, v25.4s, v27.4s + zip2 v21.4s, v25.4s, v27.4s + zip1 v18.4s, v28.4s, v30.4s + zip2 v19.4s, v28.4s, v30.4s + zip1 v22.4s, v29.4s, v31.4s + zip2 v23.4s, v29.4s, v31.4s + trn1 v1.2d, v16.2d, v18.2d + trn2 v3.2d, v16.2d, v18.2d + trn1 v5.2d, v17.2d, v19.2d + trn2 v7.2d, v17.2d, v19.2d + trn1 v9.2d, v20.2d, v22.2d + trn2 v11.2d, v20.2d, v22.2d + trn1 v13.2d, v21.2d, v23.2d + trn2 v15.2d, v21.2d, v23.2d + +WRITE_OUT: + st1 {v0.4s, v1.4s}, [x4], x7 + st1 {v2.4s, v3.4s}, [x4], x7 + st1 {v4.4s, v5.4s}, [x4], x7 + st1 {v6.4s, v7.4s}, [x4], x7 + st1 {v8.4s, v9.4s}, [x4], x7 + st1 {v10.4s, v11.4s}, [x4], x7 + st1 {v12.4s, v13.4s}, [x4], x7 + st1 {v14.4s, v15.4s}, [x4] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/TiledC4MatmulFp32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/TiledC4MatmulFp32.S new file mode 100644 index 00000000..dfb70710 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/TiledC4MatmulFp32.S @@ -0,0 +1,279 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp32 +//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t ic4, size_t cal_num, size_t oc4) +//x0: dst +//x1: src +//x2: weight +//x3: cal_num +//x4: ic4 +//x5: oc4 + +sub sp, sp, #128 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] +add x9, sp, #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + +mov x7, #4 //sizeof(float) +mul x3, x3, x7 +mov x7, #64 +mul x10, x4, x7 + +cmp x5, #2 +blt LoopOcHalf +LoopOc: + mov x8, x1 + subs x9, x4, #1 + + add x6, x2, x10 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmul v18.4s, v8.4s, v2.s[0] + fmul v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmul v20.4s, v8.4s, v4.s[0] + fmul v21.4s, v8.4s, v5.s[0] + fmul v22.4s, v8.4s, v6.s[0] + fmul v23.4s, v8.4s, v7.s[0] + fmul v24.4s, v12.4s, v0.s[0] + fmul v25.4s, v12.4s, v1.s[0] + fmul v26.4s, v12.4s, v2.s[0] + fmul v27.4s, v12.4s, v3.s[0] + fmul v28.4s, v12.4s, v4.s[0] + fmul v29.4s, v12.4s, v5.s[0] + fmul v30.4s, v12.4s, v6.s[0] + fmul v31.4s, v12.4s, v7.s[0] + + beq LoopIcEnd + LoopIc: + add x2, x2, #128 + prfm pldl1keep, [x2] + prfm pldl1keep, [x2, x10] + sub x2, x2, #128 + prfm pldl1keep, [x8, #128] + prfm pldl1keep, [x8, #192] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + fmla v29.4s, v15.4s, v5.s[3] + fmla v30.4s, v15.4s, v6.s[3] + fmla v31.4s, v15.4s, v7.s[3] + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, 
[x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + fmla v24.4s, v12.4s, v0.s[0] + fmla v25.4s, v12.4s, v1.s[0] + fmla v26.4s, v12.4s, v2.s[0] + fmla v27.4s, v12.4s, v3.s[0] + fmla v28.4s, v12.4s, v4.s[0] + fmla v29.4s, v12.4s, v5.s[0] + fmla v30.4s, v12.4s, v6.s[0] + fmla v31.4s, v12.4s, v7.s[0] + + subs x9, x9, #1 + bne LoopIc + + LoopIcEnd: + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + add x7, x0, #64 + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x3 + fmla v29.4s, v15.4s, v5.s[3] + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], x3 + fmla v30.4s, v15.4s, v6.s[3] + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], x3 + mov x2, x6 + fmla v31.4s, v15.4s, v7.s[3] + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x7] + + subs x5, x5, #2 + beq LoopOcEnd + cmp x5, #2 + bge LoopOc + +LoopOcHalf: + mov x8, x1 + mov x9, x4 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopIcHalf: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, 
v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + + subs x9, x9, #1 + bne LoopIcHalf + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + +LoopOcEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransLeft.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransLeft.S new file mode 100644 index 00000000..f79abfc5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransLeft.S @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransLeft +//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6:length + +sub sp, sp, #32 +stp x19, x20, [sp] + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x3, x8 +sub x9, x9, x8 +add x7, x9, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x13, x0 + mov x15, x3 + LoopW: + mov x14, x13 + mov x17, x1 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + + sub x2, x2, x8 + mov x12, x5 + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + ld1 {v0.s}[3], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + add x19, x16, x7 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x20], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x19], #16 + fmla v17.4s, v21.4s, v0.s[3] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + + sub x2, x2, x8 + sub x12, x12, #4 + add x14, x19, x9 + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x20], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + + sub x2, x2, x8 + sub x12, x12, #3 + add x14, x16, x9 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LKEnd + LoopK: + ld1r {v31.4s}, [x17], x10 + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + ld1 {v1.4s}, [x14], #16 + fmla v0.4s, v1.4s, v31.4s + st1 {v0.4s}, 
[x2], #16 + subs x11, x11, #1 + bne LoopLength + + subs x12, x12, #1 + sub x2, x2, x8 + add x14, x14, x9 + bne LoopK + + LKEnd: + subs x15, x15, #1 + add x13, x13, x8 + add x2, x2, x8 + bne LoopW + + add x1, x1, #4 //sizeof(float) + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #32 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransRight.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransRight.S new file mode 100644 index 00000000..29907d19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm64/WinogradTransRight.S @@ -0,0 +1,160 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransRight +//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x5, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x7, x1 + mov x15, x3 + LoopW: + mov x17, x0 + mov x13, x7 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + sub x2, x2, x8 + mov x12, x5 + + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + ld1 {v0.s}[3], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + add x19, x16, x8 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x19], #16 + fmla v17.4s, v21.4s, v0.s[3] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + sub x2, x2, x8 + sub x12, x12, #4 + mov x17, x19 + + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + sub x2, x2, x8 + sub x12, x12, #3 + mov x17, x19 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LoopKEnd + + LoopK: + ld1r {v31.4s}, [x13], x10 + + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + ld1 {v1.4s}, [x17], #16 + fmla v0.4s, v1.4s, v31.4s + + st1 {v0.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength + subs x12, 
x12, #1 + + sub x2, x2, x8 + bne LoopK + LoopKEnd: + subs x15, x15, #1 + add x2, x2, x8 + add x7, x7, #4 //sizeof(float) + bne LoopW + + add x0, x0, x9 + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float16Tofloat32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float16Tofloat32.S new file mode 100644 index 00000000..e0121e93 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float16Tofloat32.S @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + .text + .align 5 + .global Float16ToFloat32 +#ifndef __APPLE__ + .type Float16ToFloat32, %function +#endif + +// void Float16ToFloat32(const float16_t *input, float *output, int number); +// r0: input, r1: output, r2: number +Float16ToFloat32: + cmp r2, #0 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop16: + vld1.16 {q0, q1}, [r0]! + vcvt.f32.f16 q3, d0 + vcvt.f32.f16 q4, d1 + vcvt.f32.f16 q5, d2 + vst1.32 {q3, q4}, [r1]! + vcvt.f32.f16 q6, d3 + subs r2, r2, #16 + vst1.32 {q5, q6}, [r1]! + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop8: + vld1.16 {q0}, [r0]! + vcvt.f32.f16 q1, d0 + vcvt.f32.f16 q2, d1 + vst1.32 {q1, q2}, [r1]! + subs r2, r2, #8 + beq LoopEnd + cmp r2, #8 + bge Loop8 + b Loop + Loop: + vldr.16 s0, [r0] + vcvtb.f32.f16 s0, s0 + vstr.32 s0, [r1] + add r0, r0, #2 + add r1, r1, #4 + subs r2, r2, #1 + bgt Loop + LoopEnd: + mov pc, lr +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float32ToFloat16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float32ToFloat16.S new file mode 100644 index 00000000..85ac9d7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Float32ToFloat16.S @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
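Both conversion kernels in this pair implement the same elementwise cast, vectorized with vcvt over 16-, 8-, then 1-element blocks. A minimal C sketch of the semantics follows; the *Ref names are hypothetical, and an fp16-capable ARM toolchain (float16_t from arm_neon.h) is assumed.

#include <arm_neon.h>  // float16_t; assumes an fp16-capable ARM toolchain

// Hypothetical reference versions of the two assembly kernels: an
// elementwise cast in each direction. The NEON code only changes how many
// elements are converted per iteration (16, 8, then 1).
void Float16ToFloat32Ref(const float16_t *input, float *output, int number) {
  for (int i = 0; i < number; ++i) output[i] = (float)input[i];
}

void Float32ToFloat16Ref(const float *input, float16_t *output, int number) {
  for (int i = 0; i < number; ++i) output[i] = (float16_t)input[i];
}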
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + .text + .align 5 + .global Float32ToFloat16 +#ifndef __APPLE__ + .type Float32ToFloat16, %function +#endif + +// void Float32ToFloat16(const float *input, float16_t *output, int number); +// r0: input, r1: output, r2: number +Float32ToFloat16: + cmp r2, #0 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop16: + vld1.32 {q0, q1}, [r0]! + vcvt.f16.f32 d0, q0 + vcvt.f16.f32 d1, q1 + vld1.32 {q2, q3}, [r0]! + vcvt.f16.f32 d2, q2 + vcvt.f16.f32 d3, q3 + vst1.16 {q0, q1}, [r1]! + subs r2, r2, #16 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop8: + vld1.32 {q0, q1}, [r0]! + vcvt.f16.f32 d0, q0 + vcvt.f16.f32 d1, q1 + vst1.16 {q0}, [r1]! + subs r2, r2, #8 + beq LoopEnd + cmp r2, #8 + bge Loop8 + b Loop + Loop: + vldr s0, [r0] + vcvtb.f16.f32 s0, s0 + vstr.16 s0, [r1] + add r0, r0, #4 + add r1, r1, #2 + subs r2, r2, #1 + bgt Loop + LoopEnd: + mov pc, lr +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/MatVecMulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/MatVecMulFp16.S new file mode 100644 index 00000000..1fed588a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/MatVecMulFp16.S @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int col) { +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: col + +asm_function MatVecMulA32NeonFp16 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r9, r10, r11, lr} + add sp, sp, #52 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + add r10, r5, r5 // stride = depth * sizeof(float16_t) + mov lr, #4 + mul r11, r10, lr // stride x 4 + + cmp r6, #4 + blt Col1Loop + +Col4Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q15, q15, q15 + + cmp r8, #8 + bge Col4Depth8 + cmp r8, #4 + bge Col4Depth4 + cmp r8, #1 + bge Col4Depth1 + b Col4End + + Col4Depth8: + vld1.16 {q8}, [r7]! + add lr, r9, r10 + vld1.16 {q0}, [r9]! + vld1.16 {q1}, [lr], r10 + vld1.16 {q2}, [lr], r10 + vld1.16 {q3}, [lr] + + vmla.f16 q9, q8, q0 + vmla.f16 q10, q8, q1 + vmla.f16 q11, q8, q2 + vmla.f16 q12, q8, q3 + sub r8, r8, #8 + cmp r8, #8 + bge Col4Depth8 + cmp r8, #4 + bge Col4Depth4 + b AddC4 + + Col4Depth4: + vld1.16 {d16}, [r7]! + add lr, r9, r10 + vld1.16 {d0}, [r9]! 
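// Col4Depth4 builds four 4-wide partial sums, d18/d20/d22/d24, one per
// output column; AddC4 below reduces each to a scalar with pairwise vpadd
// before the depth tail (Col4Depth1) finishes the dot products.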
+ vld1.16 {d2}, [lr], r10 + vld1.16 {d4}, [lr], r10 + vld1.16 {d6}, [lr] + + vmla.f16 d18, d16, d0 + vmla.f16 d20, d16, d2 + vmla.f16 d22, d16, d4 + vmla.f16 d24, d16, d6 + sub r8, r8, #4 + cmp r8, #4 + bge Col4Depth4 + + AddC4: + vpadd.f16 d0, d18, d19 + vpadd.f16 d1, d20, d21 + vpadd.f16 d2, d22, d23 + vpadd.f16 d4, d24, d25 + vpadd.f16 d30, d0, d1 + vpadd.f16 d31, d2, d4 + vpadd.f16 d30, d30, d31 + cmp r8, #1 + bge Col4Depth1 + b Col4End + + Col4Depth1: + vld1.16 {d0[0]}, [r7]! + add lr, r9, r10 + vld1.16 {d2[0]}, [r9]! + vld1.16 {d2[1]}, [lr], r10 + vld1.16 {d2[2]}, [lr], r10 + vld1.16 {d2[3]}, [lr] + + vmla.f16 d30, d2, d0[0] + subs r8, r8, #1 + bne Col4Depth1 + + Col4End: + cmp r3, #0 + beq Col4Activation + vld1.16 {d26}, [r3]! + vadd.f16 d30, d30, d26 + + Col4Activation: + cmp r4, #3 + beq Col4Relu6 + cmp r4, #1 + beq Col4Relu + b Col4Write + + Col4Relu6: + vmov.i16 q12, #6 + vcvt.f16.s16 q12, q12 + vmin.f16 d30, d30, d24 + + Col4Relu: + veor q13, q13, q13 + vmax.f16 d30, d30, d26 + + Col4Write: + vst1.16 {d30}, [r2]! + subs r6, r6, #4 + beq End + add r1, r1, r11 + cmp r6, #4 + bge Col4Loop + +Col1Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + veor q10, q10, q10 + veor q15, q15, q15 + + cmp r8, #8 + bge Col1Depth8 + cmp r8, #4 + bge Col1Depth4 + cmp r8, #1 + bge Col1Depth1 + b Col1End + + Col1Depth8: + vld1.16 {q0}, [r7]! + vld1.16 {q1}, [r9]! + vmla.f16 q10, q1, q0 + sub r8, r8, #8 + cmp r8, #8 + bge Col1Depth8 + cmp r8, #4 + bge Col1Depth4 + b AddC1 + + Col1Depth4: + vld1.16 {d0}, [r7]! + vld1.16 {d2}, [r9]! + vmla.f16 d20, d2, d0 + sub r8, r8, #4 + cmp r8, #4 + bge Col1Depth4 + + AddC1: + vpadd.f16 d30, d20, d21 + vpadd.f16 d30, d30, d20 + vpadd.f16 d30, d30, d20 + cmp r8, #1 + bge Col1Depth1 + b Col1End + + Col1Depth1: + vld1.16 {d0[0]}, [r7]! + vld1.16 {d2[0]}, [r9]! + vmla.f16 d30, d2, d0[0] + subs r8, r8, #1 + bne Col1Depth1 + + Col1End: + cmp r3, #0 + beq Col1Activation + vld1.16 {d28[0]}, [r3]! + vadd.f16 d30, d30, d28 + + Col1Activation: + cmp r4, #3 + beq Col1Relu6 + cmp r4, #1 + beq Col1Relu + b Col1Write + + Col1Relu6: + vmov.i16 d26, #6 + vcvt.f16.s16 d26, d26 + vmin.f16 d30, d30, d26 + + Col1Relu: + veor d24, d24, d24 + vmax.f16 d30, d30, d24 + + Col1Write: + vst1.16 {d30[0]}, [r2]! + subs r6, r6, #1 + beq End + add r1, r1, r10 + b Col1Loop + +End: + sub sp, sp, #52 + pop {r0-r8, r9, r10, r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S new file mode 100644 index 00000000..781b8c3b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S @@ -0,0 +1,617 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
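Before the kernel body, the math this 12x8 fp16 tile kernel realizes, as a plain C sketch. The name and the row-major layout are illustrative; the assembly actually consumes 12-row/8-column packed panels and supports the write modes dispatched below (ActType values per the asm: 1 = ReLU, 3 = ReLU6). Note that the Relu6 block further down applies vadd to q4..q7 where vmin, as used for q8..q15, appears intended.

#include <arm_neon.h>  // float16_t; fp16-capable toolchain assumed
#include <stddef.h>

// Hypothetical reference: plain row-major layout, no packing or write modes.
void MatMul12x8Ref(const float16_t *a, const float16_t *b, float16_t *dst,
                   const float16_t *bias, int act_type, int deep,
                   int row, int col) {
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      float acc = (bias != NULL) ? (float)bias[c] : 0.0f;
      for (int d = 0; d < deep; ++d) {
        acc += (float)a[r * deep + d] * (float)b[d * col + c];
      }
      if (act_type == 3 && acc > 6.0f) acc = 6.0f;                       // ReLU6 upper clamp
      if ((act_type == 1 || act_type == 3) && acc < 0.0f) acc = 0.0f;    // ReLU floor
      dst[r * col + c] = (float16_t)acc;
    }
  }
}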
+ */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + .text + .align 5 + .global MatMul12x8A32Fp16 +#ifndef __APPLE__ + .type MatMul12x8A32Fp16, %function +#endif + +// void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, +// int deep, int row, int col, int stride, bool write_mode); +// r0: a +// r1: b +// r2: dst +// r3: bias +// #4: depth +// #8: row +// #12: col +// #16: stride +// #20: writeNhwc/writeWino + +asm_function MatMul12x8A32Fp16 + // r13(sp) and r15(pc) can not be used!! + // r9 r4 is tmp register + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r3-r11, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + ldr lr, [sp, #20] + + mov r10, r1 // b + mov r11, r0 // a + mov r12, r2 // dst + + cmp lr, #2 + bne NoWinograd + mul r4, r8, r7 // stride * col + add r4, r4, r4 // r4 * sizeof(float16_t) + mov r9, #16 + mul r9, r8, r9 // stride * 8 * sizeof(float16_t) +NoWinograd: + add r8, r8, r8 // stride * sizeof(float16_t) + +a .req r0 +weight .req r1 +dst .req r2 +bias .req r3 +depth .req r5 +row .req r6 +col .req r7 +stride .req r8 +b_tmp .req r10 +a_tmp .req r11 +dst_tmp .req r12 + +.macro STORE_12x8 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x7 p1, p2, p3 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride +.endm + +.macro STORE_12x6 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x5 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x4 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x3 p1, p2 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x2 p1 + vst1.32 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x1 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_C8 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C7 p1, p2, p3, p4 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride + cmp row, \p4 + beq WriteEnd +.endm + +.macro STORE_C6 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C5 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C4 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C3 p1, p2, p3 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C2 p1, p2 + vst1.32 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +.macro STORE_C1 p1, p2 + vst1.16 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +LoopRow12: + ldr bias, [sp, #-40] + LoopCol8: + mov dst, dst_tmp + mov a, a_tmp + ldr depth, [sp, #4] + veor q4, q4, q4 + veor q5, q5, q5 + veor q6, q6, q6 + veor q7, q7, q7 + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 
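// q4..q15 form the 12x8 fp16 accumulator tile, one q register per row of
// the 12-row panel; they are cleared here and filled in LoopDepth.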
+ veor q15, q15, q15 + LoopDepth: + vld1.16 {q0, d2}, [a]! + vld1.16 {q2}, [weight]! + vmla.f16 q4, q2, d0[0] + vmla.f16 q5, q2, d0[1] + vmla.f16 q6, q2, d0[2] + vmla.f16 q7, q2, d0[3] + vmla.f16 q8, q2, d1[0] + vmla.f16 q9, q2, d1[1] + vmla.f16 q10, q2, d1[2] + vmla.f16 q11, q2, d1[3] + vmla.f16 q12, q2, d2[0] + vmla.f16 q13, q2, d2[1] + vmla.f16 q14, q2, d2[2] + vmla.f16 q15, q2, d2[3] + + subs depth, depth, #1 + bne LoopDepth + + Bias: + cmp bias, #0 + beq Activation + vld1.16 {q0}, [bias]! + vadd.f16 q4, q4, q0 + vadd.f16 q5, q5, q0 + vadd.f16 q6, q6, q0 + vadd.f16 q7, q7, q0 + vadd.f16 q8, q8, q0 + vadd.f16 q9, q9, q0 + vadd.f16 q10, q10, q0 + vadd.f16 q11, q11, q0 + vadd.f16 q12, q12, q0 + vadd.f16 q13, q13, q0 + vadd.f16 q14, q14, q0 + vadd.f16 q15, q15, q0 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i16 q2, #0x4600 + vadd.f16 q4, q4, q2 + vadd.f16 q5, q5, q2 + vadd.f16 q6, q6, q2 + vadd.f16 q7, q7, q2 + vmin.f16 q8, q8, q2 + vmin.f16 q9, q9, q2 + vmin.f16 q10, q10, q2 + vmin.f16 q11, q11, q2 + vmin.f16 q12, q12, q2 + vmin.f16 q13, q13, q2 + vmin.f16 q14, q14, q2 + vmin.f16 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f16 q4, q4, q3 + vmax.f16 q5, q5, q3 + vmax.f16 q6, q6, q3 + vmax.f16 q7, q7, q3 + vmax.f16 q8, q8, q3 + vmax.f16 q9, q9, q3 + vmax.f16 q10, q10, q3 + vmax.f16 q11, q11, q3 + vmax.f16 q12, q12, q3 + vmax.f16 q13, q13, q3 + vmax.f16 q14, q14, q3 + vmax.f16 q15, q15, q3 + + Write: + ldr lr, [sp, #20] + cmp lr, #2 + beq WriteWinograd + cmp row, #12 + bge Write12xCol + b WriteRowxCol + + WriteWinograd: + vst1.16 {q4}, [dst] + add dst, dst, r4 + vst1.16 {q5}, [dst] + add dst, dst, r4 + vst1.16 {q6}, [dst] + add dst, dst, r4 + vst1.16 {q7}, [dst] + add dst, dst, r4 + vst1.16 {q8}, [dst] + add dst, dst, r4 + vst1.16 {q9}, [dst] + add dst, dst, r4 + vst1.16 {q10}, [dst] + add dst, dst, r4 + vst1.16 {q11}, [dst] + add dst, dst, r4 + vst1.16 {q12}, [dst] + add dst, dst, r4 + vst1.16 {q13}, [dst] + add dst, dst, r4 + vst1.16 {q14}, [dst] + add dst, dst, r4 + vst1.16 {q15}, [dst] + add dst_tmp, dst_tmp, r9 + b WriteEnd + Write12xCol: + cmp col, #8 + bge Write12x8 + cmp col, #1 + beq Write12x1 + cmp col, #2 + beq Write12x2 + cmp col, #3 + beq Write12x3 + cmp col, #4 + beq Write12x4 + cmp col, #5 + beq Write12x5 + cmp col, #6 + beq Write12x6 + b Write12x7 + + WriteRowxCol: + cmp col, #8 + bge WriteRowx8 + cmp col, #1 + beq WriteRowx1 + cmp col, #2 + beq WriteRowx2 + cmp col, #3 + beq WriteRowx3 + cmp col, #4 + beq WriteRowx4 + cmp col, #5 + beq WriteRowx5 + cmp col, #6 + beq WriteRowx6 + b WriteRowx7 + + Write12x8: + STORE_12x8 q4 + STORE_12x8 q5 + STORE_12x8 q6 + STORE_12x8 q7 + STORE_12x8 q8 + STORE_12x8 q9 + STORE_12x8 q10 + STORE_12x8 q11 + STORE_12x8 q12 + STORE_12x8 q13 + STORE_12x8 q14 + STORE_12x8 q15 + b WriteEnd + WriteRowx8: + STORE_C8 q4, #1 + STORE_C8 q5, #2 + STORE_C8 q6, #3 + STORE_C8 q7, #4 + STORE_C8 q8, #5 + STORE_C8 q9, #6 + STORE_C8 q10, #7 + STORE_C8 q11, #8 + STORE_C8 q12, #9 + STORE_C8 q13, #10 + STORE_C8 q14, #11 + STORE_C8 q15, #12 + b WriteEnd + + Write12x1: + STORE_12x1 d8[0] + STORE_12x1 d10[0] + STORE_12x1 d12[0] + STORE_12x1 d14[0] + STORE_12x1 d16[0] + STORE_12x1 d18[0] + STORE_12x1 d20[0] + STORE_12x1 d22[0] + STORE_12x1 d24[0] + STORE_12x1 d26[0] + STORE_12x1 d28[0] + STORE_12x1 d30[0] + b WriteEnd + WriteRowx1: + STORE_C1 d8[0], #1 + STORE_C1 d10[0], #2 + STORE_C1 d12[0], #3 + STORE_C1 d14[0], #4 + STORE_C1 d16[0], #5 + STORE_C1 d18[0], #6 + STORE_C1 d20[0], #7 + STORE_C1 d22[0], #8 
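// WriteRowx1 stores one fp16 lane per row; the row compare baked into
// STORE_C1 branches to WriteEnd as soon as the last valid row is written.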
+ STORE_C1 d24[0], #9 + STORE_C1 d26[0], #10 + STORE_C1 d28[0], #11 + STORE_C1 d30[0], #12 + b WriteEnd + + Write12x2: + STORE_12x2 d8[0] + STORE_12x2 d10[0] + STORE_12x2 d12[0] + STORE_12x2 d14[0] + STORE_12x2 d16[0] + STORE_12x2 d18[0] + STORE_12x2 d20[0] + STORE_12x2 d22[0] + STORE_12x2 d24[0] + STORE_12x2 d26[0] + STORE_12x2 d28[0] + STORE_12x2 d30[0] + b WriteEnd + WriteRowx2: + STORE_C2 d8[0], #1 + STORE_C2 d10[0], #2 + STORE_C2 d12[0], #3 + STORE_C2 d14[0], #4 + STORE_C2 d16[0], #5 + STORE_C2 d18[0], #6 + STORE_C2 d20[0], #7 + STORE_C2 d22[0], #8 + STORE_C2 d24[0], #9 + STORE_C2 d26[0], #10 + STORE_C2 d28[0], #11 + STORE_C2 d30[0], #12 + b WriteEnd + + Write12x3: + STORE_12x3 d8[0], d8[2] + STORE_12x3 d10[0], d10[2] + STORE_12x3 d12[0], d12[2] + STORE_12x3 d14[0], d14[2] + STORE_12x3 d16[0], d16[2] + STORE_12x3 d18[0], d18[2] + STORE_12x3 d20[0], d20[2] + STORE_12x3 d22[0], d22[2] + STORE_12x3 d24[0], d24[2] + STORE_12x3 d26[0], d26[2] + STORE_12x3 d28[0], d28[2] + STORE_12x3 d30[0], d30[2] + b WriteEnd + WriteRowx3: + STORE_C3 d8[0], d8[2], #1 + STORE_C3 d10[0], d10[2], #2 + STORE_C3 d12[0], d12[2], #3 + STORE_C3 d14[0], d14[2], #4 + STORE_C3 d16[0], d16[2], #5 + STORE_C3 d18[0], d18[2], #6 + STORE_C3 d20[0], d20[2], #7 + STORE_C3 d22[0], d22[2], #8 + STORE_C3 d24[0], d24[2], #9 + STORE_C3 d26[0], d26[2], #10 + STORE_C3 d28[0], d28[2], #11 + STORE_C3 d30[0], d30[2], #12 + b WriteEnd + + Write12x4: + STORE_12x4 d8 + STORE_12x4 d10 + STORE_12x4 d12 + STORE_12x4 d14 + STORE_12x4 d16 + STORE_12x4 d18 + STORE_12x4 d20 + STORE_12x4 d22 + STORE_12x4 d24 + STORE_12x4 d26 + STORE_12x4 d28 + STORE_12x4 d30 + b WriteEnd + WriteRowx4: + STORE_C4 d8, #1 + STORE_C4 d10, #2 + STORE_C4 d12, #3 + STORE_C4 d14, #4 + STORE_C4 d16, #5 + STORE_C4 d18, #6 + STORE_C4 d20, #7 + STORE_C4 d22, #8 + STORE_C4 d24, #9 + STORE_C4 d26, #10 + STORE_C4 d28, #11 + STORE_C4 d30, #12 + b WriteEnd + + Write12x5: + STORE_12x5 d8, d9[0] + STORE_12x5 d10, d11[0] + STORE_12x5 d12, d13[0] + STORE_12x5 d14, d15[0] + STORE_12x5 d16, d17[0] + STORE_12x5 d18, d19[0] + STORE_12x5 d20, d21[0] + STORE_12x5 d22, d23[0] + STORE_12x5 d24, d25[0] + STORE_12x5 d26, d27[0] + STORE_12x5 d28, d29[0] + STORE_12x5 d30, d31[0] + b WriteEnd + WriteRowx5: + STORE_C5 d8, d9[0], #1 + STORE_C5 d10, d11[0], #2 + STORE_C5 d12, d13[0], #3 + STORE_C5 d14, d15[0], #4 + STORE_C5 d16, d17[0], #5 + STORE_C5 d18, d19[0], #6 + STORE_C5 d20, d21[0], #7 + STORE_C5 d22, d23[0], #8 + STORE_C5 d24, d25[0], #9 + STORE_C5 d26, d27[0], #10 + STORE_C5 d28, d29[0], #11 + STORE_C5 d30, d31[0], #12 + b WriteEnd + + Write12x6: + STORE_12x6 d8, d9[0] + STORE_12x6 d10, d11[0] + STORE_12x6 d12, d13[0] + STORE_12x6 d14, d15[0] + STORE_12x6 d16, d17[0] + STORE_12x6 d18, d19[0] + STORE_12x6 d20, d21[0] + STORE_12x6 d22, d23[0] + STORE_12x6 d24, d25[0] + STORE_12x6 d26, d27[0] + STORE_12x6 d28, d29[0] + STORE_12x6 d30, d31[0] + b WriteEnd + WriteRowx6: + STORE_C6 d8, d9[0], #1 + STORE_C6 d10, d11[0], #2 + STORE_C6 d12, d13[0], #3 + STORE_C6 d14, d15[0], #4 + STORE_C6 d16, d17[0], #5 + STORE_C6 d18, d19[0], #6 + STORE_C6 d20, d21[0], #7 + STORE_C6 d22, d23[0], #8 + STORE_C6 d24, d25[0], #9 + STORE_C6 d26, d27[0], #10 + STORE_C6 d28, d29[0], #11 + STORE_C6 d30, d31[0], #12 + b WriteEnd + + Write12x7: + STORE_12x7 d8, d9[0], d9[2] + STORE_12x7 d10, d11[0], d11[2] + STORE_12x7 d12, d13[0], d13[2] + STORE_12x7 d14, d15[0], d15[2] + STORE_12x7 d16, d17[0], d17[2] + STORE_12x7 d18, d19[0], d19[2] + STORE_12x7 d20, d21[0], d21[2] + STORE_12x7 d22, d23[0], d23[2] + STORE_12x7 d24, 
d25[0], d25[2] + STORE_12x7 d26, d27[0], d27[2] + STORE_12x7 d28, d29[0], d29[2] + STORE_12x7 d30, d31[0], d31[2] + b WriteEnd + WriteRowx7: + STORE_C7 d8, d9[0], d9[2], #1 + STORE_C7 d10, d11[0], d11[2], #2 + STORE_C7 d12, d13[0], d13[2], #3 + STORE_C7 d14, d15[0], d15[2], #4 + STORE_C7 d16, d17[0], d17[2], #5 + STORE_C7 d18, d19[0], d19[2], #6 + STORE_C7 d20, d21[0], d21[2], #7 + STORE_C7 d22, d23[0], d23[2], #8 + STORE_C7 d24, d25[0], d25[2], #9 + STORE_C7 d26, d27[0], d27[2], #10 + STORE_C7 d28, d29[0], d29[2], #11 + STORE_C7 d30, d31[0], d31[2], #12 + b WriteEnd + + WriteEnd: + cmp col, #8 + ble LoopColEnd + sub col, col, #8 + ldr lr, [sp, #20] + cmp lr, #2 + beq LoopCol8 + add dst_tmp, dst_tmp, #16 + b LoopCol8 + LoopColEnd: + cmp row, #12 + ble LoopRowEnd + sub row, row, #12 + mov a_tmp, a + mov weight, b_tmp + ldr lr, [sp, #20] + cmp lr, #2 + beq WinogradDst + ldr lr, [sp, #12] + sub lr, lr, col + add lr, lr, lr // col *= 2 + sub dst_tmp, dst, lr + b LoopRow + WinogradDst: + add dst_tmp, dst, r9 + LoopRow: + mov dst, dst_tmp + ldr col, [sp, #12] + b LoopRow12 +LoopRowEnd: + sub sp, sp, #104 + vpop {q4-q7} + pop {r3-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S new file mode 100644 index 00000000..fa32c368 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S @@ -0,0 +1,108 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp16 +// void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4, +// size_t oc4); +// r0: dst +// r1: src +// r2: weight +// r3: cal_num +// r4(sp): ic4 +// r5(sp + #4): oc4 +push {r4-r11, lr} +vpush {q4-q7} +add sp, sp, #100 +ldr r4, [sp] +ldr r5, [sp, #4] // oc4 +add r3, r3, r3 +mov r7, r1 + +cmp r5, #1 +blt LoopOCEnd +cmp r4, #1 +blt LoopICEnd +LoopOC: + ldr r4, [sp] + veor q15, q15, q15 + veor q14, q14, q14 + veor q13, q13, q13 + veor q12, q12, q12 + LoopIC: + vld1.16 {q4, q5}, [r2]! // weight + vld1.16 {q2, q3}, [r1]! // 16 number src + vmla.f16 d24, d8, d4[0] + vmla.f16 d24, d9, d4[1] + vmla.f16 d24, d10, d4[2] + vmla.f16 d24, d11, d4[3] + + vmla.f16 d25, d8, d5[0] + vmla.f16 d25, d9, d5[1] + vmla.f16 d25, d10, d5[2] + vmla.f16 d25, d11, d5[3] + + vmla.f16 d26, d8, d6[0] + vmla.f16 d26, d9, d6[1] + vmla.f16 d26, d10, d6[2] + vmla.f16 d26, d11, d6[3] + + vmla.f16 d27, d8, d7[0] + vmla.f16 d27, d9, d7[1] + vmla.f16 d27, d10, d7[2] + vmla.f16 d27, d11, d7[3] + + vld1.16 {q0, q1}, [r1]! 
// 16 number src + vmla.f16 d28, d8, d0[0] + vmla.f16 d28, d9, d0[1] + vmla.f16 d28, d10, d0[2] + vmla.f16 d28, d11, d0[3] + + vmla.f16 d29, d8, d1[0] + vmla.f16 d29, d9, d1[1] + vmla.f16 d29, d10, d1[2] + vmla.f16 d29, d11, d1[3] + + vmla.f16 d30, d8, d2[0] + vmla.f16 d30, d9, d2[1] + vmla.f16 d30, d10, d2[2] + vmla.f16 d30, d11, d2[3] + + vmla.f16 d31, d8, d3[0] + vmla.f16 d31, d9, d3[1] + vmla.f16 d31, d10, d3[2] + vmla.f16 d31, d11, d3[3] + + subs r4, r4, #1 + bne LoopIC + b LoopICEnd + LoopICEnd: + mov lr, r0 + vst1.16 {q12, q13}, [lr]! + vst1.16 {q14, q15}, [lr]! + add r0, r0, r3 // dst += cal_num + mov r1, r7 + subs r5, r5, #1 + bne LoopOC +LoopOCEnd: + sub sp, sp, #100 + vpop {q4-q7} + pop {r4-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransLeft.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransLeft.S new file mode 100644 index 00000000..334ff48e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransLeft.S @@ -0,0 +1,165 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, +// size_t length); +//r0: S +//r1: B +//r2: M +//r3: w +//r4: h +//r5: k +//r6: length +asm_function WinogradTransLeftFp16 + push {r0, r3, r4-r11, lr} + vpush {q4-q7} + add sp, sp, #108 + ldr r4, [sp] + ldr r6, [sp, #8] + + mov r8, #8 // 4 * sizeof(float16_t) + mul r8, r6, r8 // length * 4 * 2 + mul r7, r3, r8 // step for S + add r10, r4, r4 // step for B + +cmp r4, #1 +blt LoopHEnd +cmp r3, #1 +blt LoopHEnd +LoopH: + ldr r3, [sp, #-40] // w + ldr r0, [sp, #-44] + LoopW: + mov r11, r0 // S + mov lr, r1 // B_src + veor q6, q6, q6 + ldr r6, [sp, #8] + InitZero: + vst1.16 {d12}, [r2]! + subs r6, r6, #1 + bne InitZero + sub r2, r2, r8 + + ldr r5, [sp, #4] + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + cmp r5, #1 + bge LoopK1 + b LoopKEnd + + LoopK4: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + vld1.16 {d7[0]}, [lr], r10 + + add r12, r11, r7 + add r14, r12, r7 + add r9, r14, r7 + LoopK4L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r14]! + vmla.f16 d12, d2, d3[0] + vld1.16 {d6}, [r9]! + vmla.f16 d12, d4, d5[0] + vmla.f16 d12, d6, d7[0] + vst1.16 {d12}, [r2]! 
// dst + subs r6, r6, #1 // length + bne LoopK4L4 + + subs r5, r5, #4 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r9, r9, r8 + add r11, r9, r7 + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK3: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + + add r12, r11, r7 + add r9, r12, r7 + LoopK3L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r9]! + vmla.f16 d12, d2, d3[0] + vmla.f16 d12, d4, d5[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK3L4 + + subs r5, r5, #3 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r9, r9, r8 + add r11, r9, r7 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK1: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + + LoopK1L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vmla.f16 d12, d0, d1[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK1L4 + + subs r5, r5, #1 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r11, r11, r8 + add r11, r11, r7 + b LoopK1 + LoopKEnd: + add r0, r0, r8 // S += unitstep + subs r3, r3, #1 + bne LoopW + LoopWEnd: + subs r4, r4, #1 + beq LoopHEnd + add r1, r1, #2 // B += 1 + b LoopH +LoopHEnd: + sub sp, sp, #108 + vpop {q4-q7} + pop {r0, r3, r4-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransRight.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransRight.S new file mode 100644 index 00000000..cb3a297a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/arm82_aarch32_fp16/WinogradTransRight.S @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, +// size_t length); +//r0: S +//r1: B +//r2: M +//r3: w +//r4: h +//r5: k +//r6: length +asm_function WinogradTransRightFp16 + push {r1, r3, r4-r11, lr} + vpush {q4-q7} + add sp, sp, #108 + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + mov r8, #8 // 4 * sizeof(float16_t) + mul r8, r6, r8 // length * 4 * 2 + mul r7, r5, r8 // step for S = k * unitStep * 4 + add r10, r4, r4 // step for B = 2 * h + +cmp r4, #1 +blt LoopHEnd +cmp r3, #1 +blt LoopHEnd +LoopH: + ldr r3, [sp, #-40] // w + ldr r1, [sp, #-44] + LoopW: + mov r11, r0 // S + mov lr, r1 // B_src + veor q6, q6, q6 + ldr r6, [sp, #8] + InitZero: + vst1.16 {d12}, [r2]! 
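// InitZero writes `length` zeroed 4-lane fp16 groups into M, after which
// r2 is rewound (sub r2, r2, r8) so LoopK4/LoopK3/LoopK1 can accumulate
// in place.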
+ subs r6, r6, #1 + bne InitZero + sub r2, r2, r8 + + ldr r5, [sp, #4] + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + cmp r5, #1 + bge LoopK1 + b LoopKEnd + + LoopK4: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + vld1.16 {d7[0]}, [lr], r10 + + add r12, r11, r8 + add r14, r12, r8 + add r9, r14, r8 + LoopK4L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r14]! + vmla.f16 d12, d2, d3[0] + vld1.16 {d6}, [r9]! + vmla.f16 d12, d4, d5[0] + vmla.f16 d12, d6, d7[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK4L4 + + subs r5, r5, #4 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + mov r11, r9 + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK3: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + + add r12, r11, r8 + add r9, r12, r8 + LoopK3L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r9]! + vmla.f16 d12, d2, d3[0] + vmla.f16 d12, d4, d5[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK3L4 + + subs r5, r5, #3 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + mov r11, r9 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK1: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + + LoopK1L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vmla.f16 d12, d0, d1[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK1L4 + + subs r5, r5, #1 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + b LoopK1 + + LoopKEnd: + add r1, r1, #2 // B[x] + subs r3, r3, #1 + bne LoopW + LoopWEnd: + add r0, r0, r7 + subs r4, r4, #1 + beq LoopHEnd + b LoopH +LoopHEnd: + sub sp, sp, #108 + vpop {q4-q7} + pop {r1, r3, r4-r11, pc} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S new file mode 100644 index 00000000..0e9eac3c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32Avx3x3.S @@ -0,0 +1,313 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
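Per output pixel, the 3x3 depthwise kernel below is a 9-tap multiply-accumulate over pre-gathered input row pointers. A hedged C sketch of that per-pixel math follows; it uses a tap-major weight layout for readability, whereas the real kernel blocks channels by 8 with a matching blocked weight layout.

#include <stddef.h>

// Hypothetical reference for one output pixel: `input` holds 9 pre-gathered
// row pointers (one per 3x3 tap). Weights here are tap-major
// (weights[tap * channels + c]); the assembly's 8-channel-blocked layout
// only changes the indexing, not the math.
static void ConvDw3x3PixelRef(float *output, const float *const *input,
                              const float *weights, const float *bias,
                              size_t channels, int relu, int relu6) {
  for (size_t c = 0; c < channels; ++c) {
    float acc = bias[c];
    for (int tap = 0; tap < 9; ++tap) {
      acc += input[tap][c] * weights[tap * channels + c];
    }
    if (relu6 && acc > 6.0f) acc = 6.0f;
    if ((relu || relu6) && acc < 0.0f) acc = 0.0f;
    output[c] = acc;
  }
}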
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/assembly_global.h" +.text +.align 4 + +// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, size_t output_width, +// size_t input_stride, size_t relum, szie_t relu6) +// in linux x64 platform: +// rdi: output +// rsi: input +// rdx: weights +// rcx: bias +// r8: channels +// r9: output_width +// 8: input_stride +// 16: relu +// 24: relu6 + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output +// rdx: input +// r8: weights +// r9: bias +// 40: channels +// 48: output_width +// 56: input_stride +// 64: relu +// 72: relu6 +asm_function ConvDwFp32Avx3x3 + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 // -56 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi // -88 + pushq %rdi // -96 + addq $96, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi + movq %rdx, %rsi + movq %r8, %rdx + movq %r9, %rcx + movq 40(%rsp), %r8 // channels + movq 48(%rsp), %r9 // output_width + + mov %rdx, -80(%rsp) + mov %rcx, -72(%rsp) + mov %r9, -56(%rsp) + mov %r8, -64(%rsp) + movq 56(%rsp), %rbp // input_stride + movq %rbp, 8(%rsp) + movq 64(%rsp), %rbp // relu + movq %rbp, 16(%rsp) + movq 72(%rsp), %rbp // relu6 + movq %rbp, 24(%rsp) +#endif + + movq $6, %rax + vcvtsi2ss %rax, %xmm15, %xmm15 + vshufps $0, %xmm15, %xmm15, %xmm15 + vinsertf128 $1, %xmm15, %ymm15, %ymm15 + vxorps %ymm14, %ymm14, %ymm14 + + LoopPixel: + movq -80(%rsp), %rdx + movq -72(%rsp), %rcx + movq -64(%rsp), %r8 + movq (%rsi), %r9 + movq 8(%rsi), %r10 + movq 16(%rsi), %r11 + movq 24(%rsi), %r12 + movq 32(%rsi), %r13 + movq 40(%rsi), %r14 + movq 48(%rsi), %r15 + movq 56(%rsi), %rbp + movq 64(%rsi), %rbx + + vmovups (%r9), %ymm0 + addq $32, %r9 + vmovups (%r10), %ymm1 + addq $32, %r10 + vmovups (%r11), %ymm2 + addq $32, %r11 + + vmovups (%rdx), %ymm11 + addq $32, %rdx + vmovups (%rdx), %ymm12 + addq $32, %rdx + vmovups (%rdx), %ymm13 + addq $32, %rdx + + vmovups (%rcx), %ymm10 + addq $32, %rcx + + cmpq $8, %r8 + jbe LeftLoop + LoopC8: + vfmadd231ps %ymm11, %ymm0, %ymm10 + vmovups (%r12), %ymm3 + addq $32, %r12 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm1, %ymm10 + vmovups (%r13), %ymm4 + addq $32, %r13 + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm2, %ymm10 + vmovups (%r14), %ymm5 + addq $32, %r14 + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm3, %ymm10 + vmovups (%r15), %ymm6 + addq $32, %r15 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm4, %ymm10 + vmovups (%rbp), %ymm7 + addq $32, %rbp + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm5, %ymm10 + vmovups (%rbx), %ymm8 + addq $32, %rbx + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm6, %ymm10 + vmovups (%r9), %ymm0 + addq $32, %r9 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm7, %ymm10 + vmovups (%r10), %ymm1 + addq $32, %r10 + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm8, %ymm10 + vmovups (%r11), %ymm2 + addq $32, %r11 + vmovups (%rdx), %ymm13 + addq $32, %rdx + + movq 24(%rsp), %rax + cmpq $0, %rax + jne Relu6 + movq 16(%rsp), %rax + cmpq $0, %rax + jne Relu + jmp Write + Relu6: + vminps %ymm15, %ymm10, %ymm10 + Relu: + vmaxps %ymm14, %ymm10, %ymm10 + Write: + vmovups %ymm10, (%rdi) + addq $32, %rdi + + vmovups (%rcx), %ymm10 + addq $32, %rcx + subq $8, %r8 + cmpq $8, %r8 + ja LoopC8 + + LeftLoop: + 
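// LeftLoop handles the final 1..8 channels of the current pixel; LeftWrite
// then dispatches to Write1..Write8 so only the valid lanes are stored.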
vfmadd231ps %ymm11, %ymm0, %ymm10 + vmovups (%r12), %ymm3 + addq $32, %r12 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm1, %ymm10 + vmovups (%r13), %ymm4 + addq $32, %r13 + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm2, %ymm10 + vmovups (%r14), %ymm5 + addq $32, %r14 + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm3, %ymm10 + vmovups (%r15), %ymm6 + addq $32, %r15 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm4, %ymm10 + vmovups (%rbp), %ymm7 + addq $32, %rbp + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm5, %ymm10 + vmovups (%rbx), %ymm8 + addq $32, %rbx + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm6, %ymm10 + vfmadd231ps %ymm12, %ymm7, %ymm10 + vfmadd231ps %ymm13, %ymm8, %ymm10 + + movq 24(%rsp), %rax + cmpq $0, %rax + jne LeftRelu6 + movq 16(%rsp), %rax + cmpq $0, %rax + jne LeftRelu + jmp LeftWrite + LeftRelu6: + vminps %ymm15, %ymm10, %ymm10 + LeftRelu: + vmaxps %ymm14, %ymm10, %ymm10 + LeftWrite: + cmpq $1, %r8 + je Write1 + cmpq $2, %r8 + je Write2 + cmpq $3, %r8 + je Write3 + cmpq $4, %r8 + je Write4 + cmpq $5, %r8 + je Write5 + cmpq $6, %r8 + je Write6 + cmpq $7, %r8 + je Write7 + jmp Write8 + Write1: + vmovss %xmm10, (%rdi) + addq $4, %rdi + jmp NextPixel + Write2: + vmovsd %xmm10, (%rdi) + addq $8, %rdi + jmp NextPixel + Write3: + vmovsd %xmm10, (%rdi) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 8(%rdi) + addq $12, %rdi + jmp NextPixel + Write4: + vmovups %xmm10, (%rdi) + addq $16, %rdi + jmp NextPixel + Write5: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovss %xmm9, 16(%rdi) + addq $20, %rdi + jmp NextPixel + Write6: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovsd %xmm9, 16(%rdi) + addq $24, %rdi + jmp NextPixel + Write7: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovsd %xmm9, 16(%rdi) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 24(%rdi) + addq $28, %rdi + jmp NextPixel + Write8: + vmovups %ymm10, (%rdi) + add $32, %rdi + + NextPixel: + movq 8(%rsp), %rbp + addq %rbp, %rsi + movq -56(%rsp), %rax + subq $1, %rax + movq %rax, -56(%rsp) + cmpq $0, %rax + ja LoopPixel +End: + subq $96, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32BorderAvx.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32BorderAvx.S new file mode 100644 index 00000000..8e6c938d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32BorderAvx.S @@ -0,0 +1,188 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
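The border kernel accumulates a partial kernel window (height x width taps) into one 4-channel output group. A C sketch of the semantics follows, using element strides where the assembly passes byte steps; the name is hypothetical.

#include <stddef.h>

// Hypothetical reference: one border output position, 4 channels (one xmm).
// Strides are element counts here; the assembly receives byte steps.
static void ConvDwBorderPixelRef(float *dst, const float *src,
                                 const float *weight, const float *bias,
                                 size_t height, size_t width,
                                 size_t in_kh_step, size_t in_kw_step,
                                 size_t kernel_w_step, int relu, int relu6) {
  for (int c = 0; c < 4; ++c) {
    float acc = 0.0f;
    for (size_t kh = 0; kh < height; ++kh) {
      for (size_t kw = 0; kw < width; ++kw) {
        acc += src[kh * in_kh_step + kw * in_kw_step + c] *
               weight[kh * kernel_w_step + kw * 4 + c];
      }
    }
    acc += bias[c];
    if (relu6 && acc > 6.0f) acc = 6.0f;
    if ((relu || relu6) && acc < 0.0f) acc = 0.0f;
    dst[c] = acc;
  }
}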
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/assembly_global.h" + +.text +.align 4 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, +// size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, +// size_t relu6); + +asm_function ConvDwFp32Border + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi + pushq %rdi + addq $96, %rsp + + movq %rdi, %rdx +#ifdef _WIN32 + movq %rcx, %rdx +#endif + movq 8(%rdx), %r12 // src + movq 16(%rdx), %r13 // weight + movq 24(%rdx), %rbp // bias + movq 32(%rdx), %r11 // height + movq 40(%rdx), %r10 + movq %r10, -72(%rsp) // width + movq 48(%rdx), %r10 + movq %r10, -80(%rsp) // in_kh_step + movq 56(%rdx), %r10 // in_kw_step + movq 64(%rdx), %rax // kernel_w + movq 72(%rdx), %rcx // relu + movq 80(%rdx), %rbx // reul6 + movq $6, -64(%rsp) + movq (%rdx), %rdx + cmpq $0, %r11 + je End + + xorps %xmm8, %xmm8 + LoopHeight: + movq %r12, %rsi // src_kh, src_kw + movq %r13, %rdi // weight_kh, weight_kw + movq -72(%rsp), %r8 // width + + cmpq $6, %r8 + jae LoopWidth6 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $1, %r8 + jae LoopWidth1 + jmp LoopWidthEnd + + LoopWidth6: + xorps %xmm6, %xmm6 + xorps %xmm7, %xmm7 + imul $3, %r10, %r9 + addq %rsi, %r9 + vmovups (%rsi), %xmm0 // src_kw + vmovups (%rsi, %r10), %xmm1 + vmovups (%rsi, %r10, 2), %xmm2 + vmovups (%r9), %xmm3 + vmovups (%rsi, %r10, 4), %xmm4 + vmovups (%r9, %r10, 2), %xmm5 + + vfmadd231ps (%rdi), %xmm0, %xmm6 + vfmadd231ps 16(%rdi), %xmm1, %xmm7 + vfmadd231ps 32(%rdi), %xmm2, %xmm8 + vfmadd231ps 48(%rdi), %xmm3, %xmm6 + vfmadd231ps 64(%rdi), %xmm4, %xmm7 + vfmadd231ps 80(%rdi), %xmm5, %xmm8 + + addps %xmm6, %xmm7 + imul $6, %r10, %r15 + addq $96, %rdi + addps %xmm7, %xmm8 + addq %r15, %rsi + + subq $6, %r8 + cmpq $6, %r8 + jae LoopWidth6 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $0, %r8 + je LoopWidthEnd + jmp LoopWidth1 + + LoopWidth4: + xorps %xmm6, %xmm6 + xorps %xmm7, %xmm7 + imul $3, %r10, %r9 + addq %rsi, %r9 + vmovups (%rsi), %xmm0 // src_kw + vmovups (%rsi, %r10, 1), %xmm1 + vmovups (%rsi, %r10, 2), %xmm2 + vmovups (%r9), %xmm3 + + vfmadd231ps (%rdi), %xmm0, %xmm6 + vfmadd231ps 16(%rdi), %xmm1, %xmm7 + vfmadd231ps 32(%rdi), %xmm2, %xmm8 + vfmadd231ps 48(%rdi), %xmm3, %xmm6 + + addps %xmm6, %xmm7 + imul $4, %r10, %r15 + addq $64, %rdi + addps %xmm7, %xmm8 + addq %r15, %rsi + + subq $4, %r8 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $0, %r8 + je LoopWidthEnd + jmp LoopWidth1 + + LoopWidth1: + vmovups (%rsi), %xmm0 // input_tmp + addq %r10, %rsi + vfmadd231ps (%rdi), %xmm0, %xmm8 + addq $16, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopWidth1 + jmp LoopWidthEnd + + LoopWidthEnd: + subq $1, %r11 + cmpq $0, %r11 + je LoopHeightEnd + addq -80(%rsp), %r12 // in_kh_step + addq %rax, %r13 // kernel_w_step + jmp LoopHeight + + LoopHeightEnd: + xorps %xmm10, %xmm10 + vbroadcastss -64(%rsp), %xmm9 + + addps (%rbp), %xmm8 + cmpq $1, %rbx + je Relu6 + cmpq $1, %rcx + je Relu + jmp Write + Relu6: + minps %xmm9, %xmm8 + Relu: + maxps %xmm10, %xmm8 + Write: + movups %xmm8, (%rdx) +End: + subq $96, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowAvx.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowAvx.S new file mode 100644 
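The next file vectorizes a per-channel multiply-accumulate over a row of pixels; the reference semantics in C (name hypothetical):

#include <stddef.h>

// Hypothetical reference: accumulate weight * input into an output row that
// has already been initialized; the AVX kernel does the same 32/16/8
// channels at a time. `input_step` is in elements (the asm scales it by 4).
void ConvDwFp32RowRef(float *output, const float *input, const float *weight,
                      size_t num_pixels, size_t channels, size_t input_step) {
  for (size_t p = 0; p < num_pixels; ++p) {
    for (size_t c = 0; c < channels; ++c) {
      output[c] += input[c] * weight[c];
    }
    output += channels;
    input += input_step;
  }
}

The Opt variant added after this file (ConvDwAVXFp32Row) computes the same thing but can seed the output from the bias on the first pass (OutputInitByBias) and unrolls three pixels per iteration.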
index 00000000..2b936afb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowAvx.S @@ -0,0 +1,189 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl_c/assembly_global.h" + +.text +.align 4 + +// void ConvDwFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, +// size_t output_channel, size_t input_step); +// in linux x64 platform: +// rdi: output_ptr +// rsi: input_ptr +// rdx: weight_ptr +// rcx: num_pixels +// r8: output_channel +// r9: input_step + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output_ptr +// rdx: input_ptr +// r8: weight_ptr +// r9: num_pixels +// 40: output_channel +// 48: input_step + +asm_function ConvDwFp32Row + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rsi + pushq %rdi + addq $48, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi // output_ptr + movq %rdx, %rsi // input_ptr + movq %r8, %rdx // weight_ptr + movq %r9, %rcx // num_pixels + movq 40(%rsp), %r8 // output_channel + movq 48(%rsp), %r9 // input_step +#endif + + movq $4, %r13 + imul %r13, %r9 + movq %rsi, %r13 // input_ptr + movq %rdx, %r14 // weight_ptr + movq %r8, %r15 // output_channel + cmpq $0, %rcx + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $32, %r8 + jae LoopC32 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC32: + vmovups (%rsi), %ymm0 // input_tmp + vmovups 32(%rsi), %ymm1 + vmovups 64(%rsi), %ymm2 + vmovups 96(%rsi), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups 32(%rdi), %ymm9 + vmovups 64(%rdi), %ymm10 + vmovups 96(%rdi), %ymm11 + + addq $128, %rsi + vfmadd231ps (%rdx), %ymm0, %ymm8 + vfmadd231ps 32(%rdx), %ymm1, %ymm9 + vfmadd231ps 64(%rdx), %ymm2, %ymm10 + vfmadd231ps 96(%rdx), %ymm3, %ymm11 + + vmovups %ymm8, (%rdi) // output_ptr + vmovups %ymm9, 32(%rdi) + vmovups %ymm10, 64(%rdi) + vmovups %ymm11, 96(%rdi) + addq $128, %rdi + addq $128, %rdx + + subq $32, %r8 + cmpq $32, %r8 + jae LoopC32 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC16: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + vmovups 32(%rsi), %ymm1 + vmovups 32(%rdi), %ymm9 + addq $64, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + vfmadd231ps 32(%rdx), %ymm1, %ymm9 + + vmovups %ymm8, (%rdi) // output_ptr + addq $64, %rdx + vmovups %ymm9, 32(%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja 
LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowOptAVX.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowOptAVX.S new file mode 100644 index 00000000..9492bd6a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/ConvDwFp32RowOptAVX.S @@ -0,0 +1,382 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl_c/assembly_global.h" + +.text +.align 4 + +// void ConvDwAVXFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, +// size_t output_channel, size_t input_step); +// in linux x64 platform: +// rdi: output_ptr +// rsi: input_ptr +// rdx: weight_ptr +// rcx: num_pixels +// r8: output_channel +// r9: input_step + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output_ptr +// rdx: input_ptr +// r8: weight_ptr +// r9: num_pixels +// 40: output_channel +// 48: input_step + +asm_function ConvDwAVXFp32Row + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rsi + pushq %rdi + addq $48, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi // output_ptr + movq %rdx, %rsi // input_ptr + movq %r8, %rdx // weight_ptr + movq %r9, %rcx // num_pixels + movq 40(%rsp), %r8 // output_channel + movq 48(%rsp), %r9 // input_step + movq 56(%rsp), %r11 // first_calc_flag + movq 64(%rsp), %r10 // bias +#else + movq 8(%rsp), %r11 // first_calc_flag + movq 16(%rsp), %r10 // bias +#endif + + + movq $4, %r13 + imul %r13, %r9 + movq %r8, %r12 + imul %r13, %r12 + movq %rsi, %r13 // input_ptr + movq %rdx, %r14 // weight_ptr + movq %r8, %r15 // output_channel + + cmpq $1, %r11 + je OutputInitByBias + jmp OutputInitBySelf + +OutputInitByBias: + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%r10), %ymm8 // output_tmp + vmovups (%r10), %ymm9 // output_tmp + vmovups (%r10), %ymm10 // output_tmp + // vmovups (%r10), %ymm11 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, 
%ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%r10), %xmm8 // output_ptr + vmovss (%r10), %xmm9 // output_tmp + vmovss (%r10), %xmm10 // output_tmp + // vmovss (%r10), %xmm11 // output_tmp + addq $4, %r10 + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%r10), %ymm8 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%r10), %xmm8 // output_ptr + addq $4, %r10 + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp BiasLoopPixel + +OutputInitBySelf: + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups (%rdi, %r12), %ymm9 // output_tmp + vmovups (%rdi, %r12, 2), %ymm10 // output_tmp + // vmovups (%rdi, %r12, 3), %ymm11 // output_tmp + addq $32, %rsi + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + 
vmovss (%rdi), %xmm8 // output_ptr + vmovss (%rdi, %r12), %xmm9 // output_tmp + vmovss (%rdi, %r12, 2), %xmm10 // output_tmp + // vmovss (%rdi, %r12, 3), %xmm11 // output_tmp + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/MatmulAvx.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/MatmulAvx.S new file mode 100644 index 00000000..85a64041 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx/MatmulAvx.S @@ -0,0 +1,993 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
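+ *
+ * Reference model (a hedged sketch, not part of the kernel): MatmulFloatAvxOpt
+ * below computes a 6-row x 16-column register-tiled C = A * B (+ bias,
+ * + activation). The packed layouts (A in 6-row depth-major blocks, B in
+ * 16-column depth-major blocks) are inferred from the tiling; this scalar
+ * model covers the default write path, with stride counted in floats:
+ *
+ *   for (int r = 0; r < row; ++r) {
+ *     for (int j = 0; j < col; ++j) {
+ *       float acc = bias ? bias[j] : 0.0f;
+ *       for (int d = 0; d < depth; ++d)
+ *         acc += a[((r / 6) * depth + d) * 6 + r % 6] *
+ *                b[((j / 16) * depth + d) * 16 + j % 16];
+ *       if (act_type == 3) acc = fminf(acc, 6.0f);                   // Relu6
+ *       if (act_type == 1 || act_type == 3) acc = fmaxf(acc, 0.0f);  // Relu
+ *       c[r * stride + j] = acc;
+ *     }
+ *   }
+ *
+ * writeMode 0 (C8) and 2 (Winograd) store the same accumulators in packed
+ * 8-column layouts instead.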
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/assembly_global.h" + +.text +.align 4 + +// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// parameters pass in Linux x86 platform: +// rdi: a +// rsi: b +// rdx: c +// rcx: bias +// r8: act_type +// r9: depth +// 8: row +// 16: col +// 24: stride +// 32: writeNhwc/writeWino + +// parameters pass in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: a +// rdx: b +// r8: c +// r9: bias +// 40: act_type +// 48: depth +// 56: row +// 64: col +// 72: stride +// 80: writeMode + +asm_function MatmulFloatAvxOpt + // rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 // -56 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi // -88 + pushq %rdi // -96 + pushq %rsi // -104 rsi + pushq %rdi // -112 rdi + addq $112, %rsp +#ifdef _WIN32 + movq %rcx, %rdi + movq %rdx, %rsi + movq %r8, %rdx + movq %r9, %rcx + movq 40(%rsp), %r8 // act_type + movq 48(%rsp), %r9 // depth + movq %r9, -56(%rsp) // r9 + movq %rcx, -72(%rsp) // rcx + movq %rdx, -80(%rsp) // rdx + movq %rsi, -88(%rsp) // rsi + movq %rdi, -96(%rsp) // rdi + + movq 56(%rsp), %rbp // row + movq %rbp, 8(%rsp) + movq 64(%rsp), %rbp // col + movq %rbp, 16(%rsp) + movq 72(%rsp), %rbp // stride + movq %rbp, 24(%rsp) + movq 80(%rsp), %rbp // weiteMode + movq %rbp, 32(%rsp) +#endif + movq 8(%rsp), %rbp + movq 16(%rsp), %rbx + movq 24(%rsp), %r10 + movq 32(%rsp), %r14 + + movq $24, %r11 + imul %r9, %r11 + cmpq $0, %r14 + jne NoC8Steps + movq $48, %r13 + imul %rbp, %r13 +NoC8Steps: + cmpq $2, %r14 + jne NoWinoSteps + movq $4, %r12 + imul %r10, %r12 + imul %rbx, %r12 + movq $32, %r13 + imul %r10, %r13 +NoWinoSteps: + movq $4, %rax + imul %rax, %r10 + +LoopRow: + movq -88(%rsp), %rsi + movq 16(%rsp), %rbx + movq -72(%rsp), %rcx + + LoopCol: + cmpq $0, %r14 + je NoReloadDst + movq -80(%rsp), %rdx + NoReloadDst: + movq -96(%rsp), %rdi + movq -56(%rsp), %r9 + + vmovups (%rsi), %ymm0 + vmovups 32(%rsi), %ymm1 + vbroadcastss (%rdi), %ymm10 + vbroadcastss 4(%rdi), %ymm11 + vbroadcastss 8(%rdi), %ymm12 + vbroadcastss 12(%rdi), %ymm13 + vbroadcastss 16(%rdi), %ymm2 + vbroadcastss 20(%rdi), %ymm3 + addq $64, %rsi + vmulps %ymm0, %ymm10, %ymm4 + vmulps %ymm1, %ymm10, %ymm5 + vmulps %ymm0, %ymm11, %ymm6 + vmulps %ymm1, %ymm11, %ymm7 + vmulps %ymm0, %ymm12, %ymm8 + vmulps %ymm1, %ymm12, %ymm9 + vmulps %ymm0, %ymm13, %ymm10 + vmulps %ymm1, %ymm13, %ymm11 + add $24, %rdi + vmulps %ymm0, %ymm2, %ymm12 + vmulps %ymm1, %ymm2, %ymm13 + vmulps %ymm0, %ymm3, %ymm14 + vmulps %ymm1, %ymm3, %ymm15 + + subq $1, %r9 + cmpq $0, %r9 + je Bias + + LoopDepth: + vmovups (%rsi), %ymm0 + vmovups 32(%rsi), %ymm1 + vbroadcastss (%rdi), %ymm2 + vbroadcastss 4(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm4 + addq $64, %rsi + vfmadd231ps %ymm1, %ymm2, %ymm5 + vbroadcastss 8(%rdi), %ymm2 + vfmadd231ps %ymm0, %ymm3, %ymm6 + vfmadd231ps %ymm1, %ymm3, %ymm7 + vbroadcastss 12(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm8 + prefetcht0 384(%rsi) + vfmadd231ps %ymm1, %ymm2, %ymm9 + vbroadcastss 16(%rdi), %ymm2 + vfmadd231ps %ymm0, %ymm3, %ymm10 + vfmadd231ps %ymm1, %ymm3, %ymm11 + vbroadcastss 20(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm12 + vfmadd231ps %ymm1, %ymm2, %ymm13 + addq $24, %rdi + vfmadd231ps %ymm0, %ymm3, %ymm14 + vfmadd231ps %ymm1, %ymm3, 
%ymm15 + + subq $1, %r9 + cmpq $0, %r9 + ja LoopDepth + + Bias: + cmpq $0, %rcx + je Activation + vmovups (%rcx), %ymm0 + vmovups 32(%rcx), %ymm1 + add $64, %rcx + vaddps %ymm0, %ymm4, %ymm4 + vaddps %ymm1, %ymm5, %ymm5 + vaddps %ymm0, %ymm6, %ymm6 + vaddps %ymm1, %ymm7, %ymm7 + vaddps %ymm0, %ymm8, %ymm8 + vaddps %ymm1, %ymm9, %ymm9 + vaddps %ymm0, %ymm10, %ymm10 + vaddps %ymm1, %ymm11, %ymm11 + vaddps %ymm0, %ymm12, %ymm12 + vaddps %ymm1, %ymm13, %ymm13 + vaddps %ymm0, %ymm14, %ymm14 + vaddps %ymm1, %ymm15, %ymm15 + + Activation: + cmpq $3, %r8 + je Relu6 + cmpq $1, %r8 + je Relu + jmp Write + + Relu6: + movq $6, %rax + vcvtsi2ss %rax, %xmm0, %xmm0 + vshufps $0, %xmm0, %xmm0, %xmm0 + vinsertf128 $1, %xmm0, %ymm0, %ymm0 + vminps %ymm0, %ymm4, %ymm4 + vminps %ymm0, %ymm5, %ymm5 + vminps %ymm0, %ymm6, %ymm6 + vminps %ymm0, %ymm7, %ymm7 + vminps %ymm0, %ymm8, %ymm8 + vminps %ymm0, %ymm9, %ymm9 + vminps %ymm0, %ymm10, %ymm10 + vminps %ymm0, %ymm11, %ymm11 + vminps %ymm0, %ymm12, %ymm12 + vminps %ymm0, %ymm13, %ymm13 + vminps %ymm0, %ymm14, %ymm14 + vminps %ymm0, %ymm15, %ymm15 + + Relu: + vxorps %ymm1, %ymm1, %ymm1 + vmaxps %ymm1, %ymm4, %ymm4 + vmaxps %ymm1, %ymm5, %ymm5 + vmaxps %ymm1, %ymm6, %ymm6 + vmaxps %ymm1, %ymm7, %ymm7 + vmaxps %ymm1, %ymm8, %ymm8 + vmaxps %ymm1, %ymm9, %ymm9 + vmaxps %ymm1, %ymm10, %ymm10 + vmaxps %ymm1, %ymm11, %ymm11 + vmaxps %ymm1, %ymm12, %ymm12 + vmaxps %ymm1, %ymm13, %ymm13 + vmaxps %ymm1, %ymm14, %ymm14 + vmaxps %ymm1, %ymm15, %ymm15 + + Write: + cmpq $2, %r14 + je WriteWino + cmpq $0, %r14 + je WriteC8 + cmpq $1, %rbx + je Write1 + cmpq $2, %rbx + je Write2 + cmpq $3, %rbx + je Write3 + cmpq $4, %rbx + je Write4 + cmpq $5, %rbx + je Write5 + cmpq $6, %rbx + je Write6 + cmpq $7, %rbx + je Write7 + cmpq $8, %rbx + je Write8 + cmpq $9, %rbx + je Write9 + cmpq $10, %rbx + je Write10 + cmpq $11, %rbx + je Write11 + cmpq $12, %rbx + je Write12 + cmpq $13, %rbx + je Write13 + cmpq $14, %rbx + je Write14 + cmpq $15, %rbx + je Write15 + jmp Write16 + + Write1: + movq %rdx, %rax + addq $4, %rax + movq %rax, -80(%rsp) + vmovss %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm14, (%rdx) + addq %r10, %rdx + addq $4, %rdx + jmp WriteEnd + Write2: + movq %rdx, %rax + addq $8, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + addq %r10, %rdx + addq $8, %rdx + jmp WriteEnd + Write3: + movq %rdx, %rax + addq $12, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 8(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 8(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 8(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 8(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + 
vmovsd %xmm12, (%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 8(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 8(%rdx) + addq %r10, %rdx + addq $12, %rdx + jmp WriteEnd + Write4: + movq %rdx, %rax + addq $16, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + addq %r10, %rdx + addq $16, %rdx + jmp WriteEnd + Write5: + movq %rdx, %rax + addq $20, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovss %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovss %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovss %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovss %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovss %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovss %xmm14, 16(%rdx) + addq %r10, %rdx + addq $20, %rdx + jmp WriteEnd + Write6: + movq %rdx, %rax + addq $24, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + addq %r10, %rdx + addq $24, %rdx + jmp WriteEnd + Write7: + movq %rdx, %rax + addq $28, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 24(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 24(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 24(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 24(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 24(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 
24(%rdx) + addq %r10, %rdx + addq $28, %rdx + jmp WriteEnd + Write8: + movq %rdx, %rax + addq $32, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + addq %r10, %rdx + addq $32, %rdx + jmp WriteEnd + Write9: + movq %rdx, %rax + addq $36, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovss %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovss %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovss %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovss %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovss %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovss %xmm15, 32(%rdx) + addq %r10, %rdx + addq $36, %rdx + jmp WriteEnd + Write10: + movq %rdx, %rax + addq $40, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + addq %r10, %rdx + addq $40, %rdx + jmp WriteEnd + Write11: + movq %rdx, %rax + addq $44, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 40(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 40(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 40(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 40(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 40(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 40(%rdx) + addq %r10, %rdx + addq $44, %rdx + jmp WriteEnd + Write12: + movq %rdx, %rax + addq $48, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + addq %r10, %rdx + addq $48, %rdx + jmp WriteEnd + Write13: + movq %rdx, %rax + addq $52, %rax + movq 
%rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovss %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovss %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovss %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovss %xmm11, 48(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovss %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovss %xmm15, 48(%rdx) + addq %r10, %rdx + addq $52, %rdx + jmp WriteEnd + Write14: + movq %rdx, %rax + addq $56, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + addq %r10, %rdx + addq $56, %rdx + jmp WriteEnd + Write15: + movq %rdx, %rax + addq $60, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 56(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 56(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 56(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 56(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 56(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 56(%rdx) + addq %r10, %rdx + addq $60, %rdx + jmp WriteEnd + WriteC8: + movq %rdx, %rax + addq %r11, %rdx + movq %rdx, %r15 + addq %r11, %rdx + movq %rdx, -80(%rsp) + vmovups %ymm4, (%rax) + vmovups %ymm6, 32(%rax) + vmovups %ymm8, 64(%rax) + vmovups %ymm10, 96(%rax) + vmovups %ymm12, 128(%rax) + vmovups %ymm14, 
160(%rax) + vmovups %ymm5, (%r15) + vmovups %ymm7, 32(%r15) + vmovups %ymm9, 64(%r15) + vmovups %ymm11, 96(%r15) + vmovups %ymm13, 128(%r15) + vmovups %ymm15, 160(%r15) + jmp WriteEnd + WriteWino: + movq %rdx, %rax + addq %r13, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + addq %r12, %rdx + vmovups %ymm6, (%rdx) + addq %r12, %rdx + vmovups %ymm8, (%rdx) + addq %r12, %rdx + vmovups %ymm10, (%rdx) + addq %r12, %rdx + vmovups %ymm12, (%rdx) + addq %r12, %rdx + vmovups %ymm14, (%rdx) + cmpq $8, %rbx + je WriteEnd + movq %rax, %rdx + addq %r13, %rax + movq %rax, -80(%rsp) + vmovups %ymm5, (%rdx) + addq %r12, %rdx + vmovups %ymm7, (%rdx) + addq %r12, %rdx + vmovups %ymm9, (%rdx) + addq %r12, %rdx + vmovups %ymm11, (%rdx) + addq %r12, %rdx + vmovups %ymm13, (%rdx) + addq %r12, %rdx + vmovups %ymm15, (%rdx) + jmp WriteEnd + Write16: + movq %rdx, %rax + addq $64, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %ymm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %ymm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %ymm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %ymm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %ymm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %ymm15, 32(%rdx) + addq %r10, %rdx + addq $64, %rdx + + WriteEnd: + cmpq $16, %rbx + jbe LoopColEnd + subq $16, %rbx + jmp LoopCol + + LoopColEnd: + movq -96(%rsp), %rdi + addq %r11, %rdi + movq %rdi, -96(%rsp) + cmpq $0, %r14 + je C8DstStep + cmpq $2, %r14 + je WinoDstStep + movq $4, %rax + movq 16(%rsp), %rbx + imul %rbx, %rax + subq %rax, %rdx + movq %rdx, -80(%rsp) + jmp NoDstStep + C8DstStep: + movq -80(%rsp), %rax + addq $384, %rax + movq %rax, -80(%rsp) + jmp NoDstStep + WinoDstStep: + addq %r13, %rdx + movq %rdx, -80(%rsp) + NoDstStep: + cmpq $6, %rbp + jbe LoopRowEnd + subq $6, %rbp + jmp LoopRow + +LoopRowEnd: + subq $112, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rdx + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx512/ConvDwFp32RowAVX512.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx512/ConvDwFp32RowAVX512.S new file mode 100644 index 00000000..7afdeb0f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/avx512/ConvDwFp32RowAVX512.S @@ -0,0 +1,499 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
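+ *
+ * Reference model (a hedged sketch, not part of the kernel): this is the
+ * AVX512 counterpart of ConvDwAVXFp32Row above, adding a 16-float ZMM step
+ * to the 8-float and scalar paths. Both kernels also take two stack
+ * arguments beyond the prototype below, first_calc_flag and bias; when
+ * first_calc_flag is set, the accumulator is seeded from bias instead of
+ * the current output. Scalar model with strides counted in floats:
+ *
+ *   for (size_t p = 0; p < num_pixels; ++p) {
+ *     for (size_t c = 0; c < output_channel; ++c) {
+ *       float acc = first_calc_flag ? bias[c]
+ *                                   : output_ptr[p * output_channel + c];
+ *       output_ptr[p * output_channel + c] =
+ *           acc + input_tmp[p * input_step + c] * weight_ptr[c];
+ *     }
+ *   }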
+ */ +#ifdef ENABLE_AVX512 +#include "nnacl_c/assembly_global.h" + +.text +.align 4 + +// void ConvDwAVX512Fp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, +// size_t output_channel, size_t input_step); +// in linux x64 platform: +// rdi: output_ptr +// rsi: input_ptr +// rdx: weight_ptr +// rcx: num_pixels +// r8: output_channel +// r9: input_step + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output_ptr +// rdx: input_ptr +// r8: weight_ptr +// r9: num_pixels +// 40: output_channel +// 48: input_step + +asm_function ConvDwAVX512Fp32Row + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rsi + pushq %rdi + addq $48, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi // output_ptr + movq %rdx, %rsi // input_ptr + movq %r8, %rdx // weight_ptr + movq %r9, %rcx // num_pixels + movq 40(%rsp), %r8 // output_channel + movq 48(%rsp), %r9 // input_step + movq 56(%rsp), %r11 // first_calc_flag + movq 64(%rsp), %r10 // bias +#else + movq 8(%rsp), %r11 // first_calc_flag + movq 16(%rsp), %r10 // bias +#endif + + movq $4, %r13 + imul %r13, %r9 + movq %r8, %r12 + imul %r13, %r12 + movq %rsi, %r13 // input_ptr + movq %rdx, %r14 // weight_ptr + movq %r8, %r15 // output_channel + + cmpq $1, %r11 + je OutputInitByBias + jmp OutputInitBySelf + +OutputInitByBias: + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $16, %r8 + jae BiasLoopC16Num4 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopC16Num4: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rsi, %r9), %zmm1 + vmovups (%rsi, %r9, 2), %zmm2 + // vmovups (%rsi, %r9, 3), %zmm3 + + vmovups (%r10), %zmm8 // output_tmp + vmovups (%r10), %zmm9 // output_tmp + vmovups (%r10), %zmm10 // output_tmp + // vmovups (%r10), %zmm11 // output_tmp + addq $64, %rsi + addq $64, %r10 + + vmovups (%rdx), %zmm15 // weight_tmp + vfmadd231ps %zmm15, %zmm0, %zmm8 + vfmadd231ps %zmm15, %zmm1, %zmm9 + vfmadd231ps %zmm15, %zmm2, %zmm10 + // vfmadd231ps %zmm15, %zmm3, %zmm11 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + vmovups %zmm9, (%rdi, %r12) + vmovups %zmm10, (%rdi, %r12, 2) + // vmovups %zmm11, (%rdi, %r12, 3) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae BiasLoopC16Num4 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%r10), %ymm8 // output_tmp + vmovups (%r10), %ymm9 // output_tmp + vmovups (%r10), %ymm10 // output_tmp + // vmovups (%r10), %ymm11 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss 
(%r10), %xmm8 // output_ptr + vmovss (%r10), %xmm9 // output_tmp + vmovss (%r10), %xmm10 // output_tmp + // vmovss (%r10), %xmm11 // output_tmp + addq $4, %r10 + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $16, %r8 + jae BiasLoopC16 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC16: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%r10), %zmm8 // output_tmp + addq $64, %rsi + addq $64, %r10 + + vfmadd231ps (%rdx), %zmm0, %zmm8 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae BiasLoopC16 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%r10), %ymm8 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%r10), %xmm8 // output_ptr + addq $4, %r10 + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp BiasLoopPixel + +OutputInitBySelf: + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $16, %r8 + jae LoopC16Num4 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC16Num4: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rsi, %r9), %zmm1 + vmovups (%rsi, %r9, 2), %zmm2 + // vmovups (%rsi, %r9, 3), %zmm3 + + vmovups (%rdi), %zmm8 // output_tmp + vmovups (%rdi, %r12), %zmm9 // output_tmp + vmovups (%rdi, %r12, 2), %zmm10 // output_tmp + // vmovups (%rdi, %r12, 3), %zmm11 // output_tmp + addq $64, %rsi + + vmovups (%rdx), %zmm15 // weight_tmp + vfmadd231ps %zmm15, %zmm0, %zmm8 + vfmadd231ps %zmm15, %zmm1, %zmm9 + vfmadd231ps %zmm15, %zmm2, %zmm10 + // vfmadd231ps %zmm15, %zmm3, %zmm11 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + vmovups %zmm9, (%rdi, %r12) + vmovups %zmm10, (%rdi, %r12, 2) + // vmovups %zmm11, (%rdi, %r12, 3) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16Num4 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups (%rdi, %r12), %ymm9 // output_tmp + vmovups 
(%rdi, %r12, 2), %ymm10 // output_tmp + // vmovups (%rdi, %r12, 3), %ymm11 // output_tmp + addq $32, %rsi + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%rdi), %xmm8 // output_ptr + vmovss (%rdi, %r12), %xmm9 // output_tmp + vmovss (%rdi, %r12, 2), %xmm10 // output_tmp + // vmovss (%rdi, %r12, 3), %xmm11 // output_tmp + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC16: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rdi), %zmm8 // output_tmp + addq $64, %rsi + + vfmadd231ps (%rdx), %zmm0, %zmm8 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/CalculateMinMaxFp16Count8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/CalculateMinMaxFp16Count8.S new file mode 100644 index 00000000..85b202ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/CalculateMinMaxFp16Count8.S @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void CalculateMinMaxCount8Fp16(const float16_t *data, int count_8, float16_t *real_min, float16_t *real_max); +// x0: data +// w1: count_8 +// x2: real_min +// x3: real_max + +asm_function CalculateMinMaxCount8Fp16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + + mov x4, x0 // reload data + mov w5, w1 // reload count + ld1 {v31.8h}, [x4] + ld1 {v30.8h}, [x4], #16 + subs w5, w5, #8 + ble Write + + Loop: + ld1 {v0.8h}, [x4], #16 + fmin v31.8h, v31.8h, v0.8h + fmax v30.8h, v30.8h, v0.8h + subs w5, w5, #8 + bgt Loop + + Write: + fminv h6, v31.8h + fmaxv h7, v30.8h + + str h6, [x2] + str h7, [x3] + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Border.S new file mode 100644 index 00000000..4822c93b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Border.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
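+ *
+ * Reference model (a hedged sketch, not part of the kernel): the border
+ * kernel below produces a single 8-channel fp16 output vector. The asm
+ * receives byte strides; this sketch uses element strides for clarity:
+ *
+ *   for (int c = 0; c < 8; ++c) {
+ *     float16_t acc = bias[c];
+ *     for (size_t kh = 0; kh < height; ++kh)
+ *       for (size_t kw = 0; kw < width; ++kw)
+ *         acc += src[kh * in_kh_step + kw * in_kw_step + c] *
+ *                weight[(kh * kernel_w + kw) * 8 + c];
+ *     if (relu6 && acc > 6) acc = 6;
+ *     if ((relu || relu6) && acc < 0) acc = 0;
+ *     dst[c] = acc;
+ *   }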
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, +// size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, +// size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step, +// x8: kernel_w, x9: relu, x10: relu6 +asm_function ConvDwFp16Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + + ld1 {v0.8h}, [x3] // bias + movi v1.8h, #0x46, lsl #8 // relu 6 + dup v2.4s, wzr // relu + + mov x13, x1 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x5 + LoopW: + ld1 {v3.8h}, [x15], x7 + ld1 {v4.8h}, [x16], #16 + fmla v0.8h, v3.8h, v4.8h + subs x17, x17, #1 + bne LoopW + subs x4, x4, #1 + add x13, x13, x6 + add x14, x14, x8 + bne LoopH + cbnz x10, Relu6 + cbnz x9, Relu + b Write + Relu6: + fmin v0.8h, v0.8h, v1.8h + Relu: + fmax v0.8h, v0.8h, v2.8h + Write: + st1 {v0.8h}, [x0] + + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Center.S new file mode 100644 index 00000000..1b253437 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Center.S @@ -0,0 +1,312 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
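+ *
+ * Reference model (a hedged sketch, not part of the kernel): the center
+ * kernel applies the same per-output arithmetic as ConvDwFp16Border above,
+ * tiled 16/8/1 output columns at a time. Per-output model with element
+ * strides (the asm receives byte strides):
+ *
+ *   for (size_t oh = 0; oh < height; ++oh)
+ *     for (size_t ow = 0; ow < width; ++ow)
+ *       for (int c = 0; c < 8; ++c) {
+ *         float16_t acc = bias[c];
+ *         for (size_t kh = 0; kh < kernel_h; ++kh)
+ *           for (size_t kw = 0; kw < kernel_w; ++kw)
+ *             acc += src[oh * in_sh_step + ow * in_sw_step +
+ *                        kh * in_kh_step + kw * in_kw_step + c] *
+ *                    weight[(kh * kernel_w + kw) * 8 + c];
+ *         if (relu6 && acc > 6) acc = 6;
+ *         if ((relu || relu6) && acc < 0) acc = 0;
+ *         dst[oh * out_h_step + ow * block_channel + c] = acc;
+ *       }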
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +asm_function ConvDwFp16Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + + ld1 {v24.8h}, [x3] + movi v26.8h, #0x46, lsl #8 + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x25, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v8.8h, v16.8h, v25.8h + fmla v9.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v10.8h, v18.8h, v25.8h + fmla v11.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v12.8h, v20.8h, v25.8h + fmla v13.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v14.8h, v22.8h, v25.8h + fmla v15.8h, v23.8h, v25.8h + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + fmin v8.8h, v8.8h, v26.8h + fmin v9.8h, v9.8h, v26.8h + fmin v10.8h, v10.8h, v26.8h + fmin v11.8h, v11.8h, v26.8h + fmin v12.8h, v12.8h, 
v26.8h + fmin v13.8h, v13.8h, v26.8h + fmin v14.8h, v14.8h, v26.8h + fmin v15.8h, v15.8h, v26.8h + Relu16: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + fmax v8.8h, v8.8h, v27.8h + fmax v9.8h, v9.8h, v27.8h + fmax v10.8h, v10.8h, v27.8h + fmax v11.8h, v11.8h, v27.8h + fmax v12.8h, v12.8h, v27.8h + fmax v13.8h, v13.8h, v27.8h + fmax v14.8h, v14.8h, v27.8h + fmax v15.8h, v15.8h, v27.8h + Write16: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + st1 {v8.8h}, [x3], x9 + st1 {v9.8h}, [x3], x9 + st1 {v10.8h}, [x3], x9 + st1 {v11.8h}, [x3], x9 + st1 {v12.8h}, [x3], x9 + st1 {v13.8h}, [x3], x9 + st1 {v14.8h}, [x3], x9 + st1 {v15.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x25, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + Relu8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + Write8: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v16.8h}, [x22], x13 + ld1 {v25.8h}, [x17], #16 + fmla v0.8h, v16.8h, v25.8h + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz x14, Relu + b Write + Relu6: + fmin v0.8h, v0.8h, v26.8h + Relu: + fmax v0.8h, v0.8h, v27.8h + Write: + st1 {v0.8h}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, 
x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S new file mode 100644 index 00000000..2238257d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/ConvDwFp16Row.S @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Row(float16_t* output_ptr, const float16_t* input_ptr,const float16_t* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels, +// x4: input_channel, x5: input_step +// +asm_function ConvDwFp16Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +beq End + +mov x9, x0 +mov x12, #2 // sizeof(float16_t) +mul x5, x5, x12 + +LoopOutPixel: +mov x6, x1 +mov x7, x2 +mov x8, x4 + +LoopInputDepth32In: + cmp x8, #32 + blt Loop8 + sub x8, x8, #32 + + ld1 {v0.8h, v1.8h}, [x6], #32 + ld1 {v2.8h, v3.8h}, [x7], #32 + ld1 {v16.8h, v17.8h}, [x0], #32 + + cmp x8, #32 + blt LoopInputDepth32Out + LoopInputDepth32: + fmla v16.8h, v0.8h, v2.8h + fmla v17.8h, v1.8h, v3.8h + + st1 {v16.8h, v17.8h}, [x9], #32 + + ld1 {v4.8h, v5.8h}, [x6], #32 + ld1 {v6.8h, v7.8h}, [x7], #32 + ld1 {v18.8h, v19.8h}, [x0], #32 + + fmla v18.8h, v4.8h, v6.8h + fmla v19.8h, v5.8h, v7.8h + + st1 {v18.8h, v19.8h}, [x9], #32 + + ld1 {v0.8h, v1.8h}, [x6], #32 + ld1 {v2.8h, v3.8h}, [x7], #32 + ld1 {v16.8h, v17.8h}, [x0], #32 + + sub x8, x8, #32 + cmp x8, #32 + bge LoopInputDepth32 + + LoopInputDepth32Out: + fmla v16.8h, v0.8h, v2.8h + fmla v17.8h, v1.8h, v3.8h + st1 {v16.8h, v17.8h}, [x9], #32 + + ld1 {v4.8h, v5.8h}, [x6], #32 + ld1 {v6.8h, v7.8h}, [x7], #32 + ld1 {v18.8h, v19.8h}, [x0], #32 + + fmla v18.8h, v4.8h, v6.8h + fmla v19.8h, v5.8h, v7.8h + + st1 {v18.8h, v19.8h}, [x9], #32 + + Loop8: + cmp x8, #8 + blt L0 + + LoopInputDepth8: + ld1 {v0.8h}, [x6], #16 + ld1 {v2.8h}, [x7], #16 + ld1 {v16.8h}, [x0], #16 + fmla v16.8h, v0.8h, v2.8h + st1 {v16.8h}, [x9], #16 + sub x8, x8, #8 + cmp x8, #8 + bge LoopInputDepth8 + + L0: + cmp x8, #0 + beq Loop8LineEnd + + LoopInputDepth0: + ldr h0, [x6], #2 + ldr h1, [x7], #2 + ldr h2, [x0], #2 + fmul h0, h0, h1 + fadd h2, h2, h0 + str h2, [x9], #2 + subs x8, x8, #1 + bne LoopInputDepth0 + + Loop8LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Border.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Border.S new file mode 
100644 index 00000000..103985c7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Border.S @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp16Border(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w) + +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: in_kh_step, x6: in_kw_step, x7: kernel_w +asm_function DeconvDwFp16Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ld1 {v1.8h}, [x1] + + mov x13, x0 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x4 + LoopW: + ld1 {v0.8h}, [x15] + ld1 {v2.8h}, [x16], #16 + fmla v0.8h, v1.8h, v2.8h + st1 {v0.8h}, [x15], x6 + subs x17, x17, #1 + bne LoopW + subs x3, x3, #1 + add x13, x13, x5 + add x14, x14, x7 + bne LoopH + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Center.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Center.S new file mode 100644 index 00000000..44f0c1ce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DeconvDwFp16Center.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
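+ *
+ * Reference model (a hedged sketch, not part of the kernel): the deconv
+ * center kernel is the scatter dual of ConvDwFp16Center: each source pixel
+ * is multiplied by every kernel tap and accumulated into the corresponding
+ * window of dst. Per-channel model with element strides (the asm receives
+ * byte strides):
+ *
+ *   for (size_t h = 0; h < height; ++h)
+ *     for (size_t w = 0; w < width; ++w)
+ *       for (size_t kh = 0; kh < kernel_h; ++kh)
+ *         for (size_t kw = 0; kw < kernel_w; ++kw)
+ *           for (int c = 0; c < 8; ++c)
+ *             dst[h * in_sh_step + w * in_sw_step +
+ *                 kh * in_kh_step + kw * in_kw_step + c] +=
+ *               src[h * out_h_step + w * block_channel + c] *
+ *               weight[(kh * kernel_w + kw) * 8 + c];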
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +asm_function DeconvDwFp16Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved, + // whereas our coding style does not permit that many parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + ldr x8, [sp, #32] + ldr x9, [sp, #40] + ldr x10, [sp, #48] + ldr x11, [sp, #56] + ldr x12, [sp, #64] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x22, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.8h}, [x16], x8 + LoopKh: + mov x21, x22 + mov x13, x6 + LoopKw: + ld1 {v0.8h}, [x21] + ld1 {v2.8h}, [x19], #16 + fmla v0.8h, v1.8h, v2.8h + st1 {v0.8h}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x22, x22, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DynamicGatherArm64ForFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DynamicGatherArm64ForFp16.S new file mode 100644 index 00000000..c27ade02 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/DynamicGatherArm64ForFp16.S @@ -0,0 +1,54 @@ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +// void DynamicGatherArm64ForFp16(const int8_t *src, float16_t *output, int count_16, int zp, float16_t scale); +// x0: src +// x1: output +// w2: count_16 +// w3: zp +// v0: scale + +asm_function DynamicGatherArm64ForFp16 + mov x5, x0 // reload src + mov x6, x1 // reload out + mov w7, w2 // reload count_16 + dup v1.4s, w3 // zp + dup v2.4s, v0.s[0] // scale + + LoopCount: + ld1 {v0.16b}, [x5], #16 + + sxtl v3.8h, v0.8b + sxtl2 v4.8h, v0.16b + + sxtl v16.4s, v3.4h + sxtl2 v17.4s, v3.8h + sxtl v18.4s, v4.4h + sxtl2 v19.4s, v4.8h + + sub v16.4s, v16.4s, v1.4s + scvtf v16.4s, v16.4s + fmul v16.4s, v16.4s, v2.4s + sub v17.4s, v17.4s, v1.4s + scvtf v17.4s, v17.4s + fmul v17.4s, v17.4s, v2.4s + sub v18.4s, v18.4s, v1.4s + scvtf v18.4s, v18.4s + fmul v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v1.4s + scvtf v19.4s, v19.4s + fmul v19.4s, v19.4s, v2.4s + + fcvtn v16.4h, v16.4s + fcvtn v17.4h, v17.4s + fcvtn v18.4h, v18.4s + fcvtn v19.4h, v19.4s + + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x6], #32 + subs w7, w7, #16 + bgt LoopCount +ret + +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float16ToFloat32.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float16ToFloat32.S new file mode 100644 index 00000000..39173658 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float16ToFloat32.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + 
* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void Float16ToFloat32(const float16_t *input, float *output, int number); +// x0: input, x1: output, x2: number +asm_function Float16ToFloat32 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + cmp x2, #0 + beq LoopEnd + cmp x2, #64 + blt Loop + Loop64: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 + fcvtl v16.4s, v0.4h + fcvtl2 v17.4s, v0.8h + fcvtl v18.4s, v1.4h + fcvtl2 v19.4s, v1.8h + fcvtl v20.4s, v2.4h + fcvtl2 v21.4s, v2.8h + fcvtl v22.4s, v3.4h + fcvtl2 v23.4s, v3.8h + fcvtl v24.4s, v4.4h + fcvtl2 v25.4s, v4.8h + fcvtl v26.4s, v5.4h + fcvtl2 v27.4s, v5.8h + fcvtl v28.4s, v6.4h + fcvtl2 v29.4s, v6.8h + fcvtl v30.4s, v7.4h + fcvtl2 v31.4s, v7.8h + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x1], #64 + subs x2, x2, #64 + ble LoopEnd + cmp x2, #64 + bge Loop64 + Loop: + ldr h0, [x0], #2 + fcvt s0, h0 + str s0, [x1], #4 + subs x2, x2, #1 + bgt Loop + LoopEnd: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S new file mode 100644 index 00000000..b40a8aae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
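+ *
+ * Reference model (a hedged sketch, not part of the kernel): the inverse of
+ * Float16ToFloat32 added above. The main loop narrows 64 floats per
+ * iteration with fcvtn/fcvtn2 and a scalar tail converts the remainder;
+ * a scalar equivalent:
+ *
+ *   for (int i = 0; i < number; ++i)
+ *     output[i] = (float16_t)input[i];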
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S
new file mode 100644
index 00000000..b40a8aae
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Float32ToFloat16.S
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void Float32ToFloat16(const float *input, float16_t *output, int number);
+// x0: input, x1: output, x2: number
+asm_function Float32ToFloat16
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved
+    // whereas our coding style does not permit such an amount of parameters
+    cmp x2, #0
+    beq LoopEnd
+    cmp x2, #64
+    blt Loop
+    Loop64:
+        ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
+        ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
+        ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
+        ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64
+        fcvtn v0.4h, v16.4s
+        fcvtn2 v0.8h, v17.4s
+        fcvtn v1.4h, v18.4s
+        fcvtn2 v1.8h, v19.4s
+        fcvtn v2.4h, v20.4s
+        fcvtn2 v2.8h, v21.4s
+        fcvtn v3.4h, v22.4s
+        fcvtn2 v3.8h, v23.4s
+        fcvtn v4.4h, v24.4s
+        fcvtn2 v4.8h, v25.4s
+        fcvtn v5.4h, v26.4s
+        fcvtn2 v5.8h, v27.4s
+        fcvtn v6.4h, v28.4s
+        fcvtn2 v6.8h, v29.4s
+        fcvtn v7.4h, v30.4s
+        fcvtn2 v7.8h, v31.4s
+        st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64
+        st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64
+        subs x2, x2, #64
+        ble LoopEnd
+        cmp x2, #64
+        bge Loop64
+    Loop:
+        ldr s0, [x0], #4
+        fcvt h0, s0
+        str h0, [x1], #2
+        subs x2, x2, #1
+        bgt Loop
+    LoopEnd:
+        ret
+#endif
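The MatVecMulFp16Neon64 kernel that follows computes c = act(a·B + bias) for a single input row a. Per the "rhs depthx1 block stride" comment in its prologue, the weights of each output element are contiguous along depth, and columns are handled four at a time with a scalar fallback. A scalar sketch of the same contract; the Ref name is illustrative, accumulation is done in fp32 here for clarity whereas the kernel accumulates in fp16, so low-order bits can differ, and the NULL-bias convention mirrors the kernel's cbz x3 check:

    #include <arm_fp16.h>
    #include <stddef.h>

    static void MatVecMulFp16Ref(const float16_t *a, const float16_t *b, float16_t *c,
                                 const float16_t *bias, int act_type, int depth, int col) {
      for (int j = 0; j < col; ++j) {
        float acc = 0.0f;
        for (int d = 0; d < depth; ++d) {
          acc += (float)a[d] * (float)b[j * depth + d];  // column j is depth-contiguous
        }
        if (bias != NULL) {
          acc += (float)bias[j];
        }
        if (act_type == 1 || act_type == 3) {  // ActType_Relu / ActType_Relu6: clamp below at 0
          acc = acc < 0.0f ? 0.0f : acc;
        }
        if (act_type == 3) {                   // ActType_Relu6: clamp above at 6
          acc = acc > 6.0f ? 6.0f : acc;
        }
        c[j] = (float16_t)acc;
      }
    }
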
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S
new file mode 100644
index 00000000..c5aa798a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatVecMulFp16.S
@@ -0,0 +1,191 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, int col)
+// x0: a
+// x1: b
+// x2: c
+// x3: bias
+// w4: act_type
+// w5: depth
+// w6: col
+
+asm_function MatVecMulFp16Neon64
+    sub sp, sp, #128
+    st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp]
+    add x9, sp, #64
+    st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9]
+
+    mov w14, #2       // sizeof(float16_t)
+    mul w8, w14, w5   // rhs depthx1 block stride
+    mov w14, #4
+    mul w13, w8, w14  // rhs depthx4 block stride
+
+Loop:
+    mov x15, x0  // reload a ptr
+    mov x7, x1   // reload b ptr
+    mov w9, w5   // reload depth
+    cmp w6, #4
+    blt Loop1x1
+
+Loop1x4:
+    dup v5.8h, wzr
+    dup v6.8h, wzr
+    dup v7.8h, wzr
+    dup v8.8h, wzr
+    dup v9.8h, wzr
+    dup v10.8h, wzr
+    dup v11.8h, wzr
+    dup v12.8h, wzr
+    dup v13.8h, wzr
+
+    add x10, x7, x8
+    add x11, x10, x8
+    add x12, x11, x8
+
+Depth8_1x4:
+    cmp w9, #8
+    blt Depth1_1x4
+
+    ld1 {v0.8h}, [x15], #16
+    ld1 {v1.8h}, [x7], #16
+    ld1 {v2.8h}, [x10], #16
+    ld1 {v3.8h}, [x11], #16
+    ld1 {v4.8h}, [x12], #16
+
+    fmla v5.8h, v1.8h, v0.8h
+    fmla v6.8h, v2.8h, v0.8h
+    fmla v7.8h, v3.8h, v0.8h
+    fmla v8.8h, v4.8h, v0.8h
+    sub w9, w9, #8
+    cbz w9, End1x4
+    b Depth8_1x4
+
+Depth1_1x4:
+    ld1 {v0.h}[0], [x15], #2
+    ld1 {v1.h}[0], [x7], #2
+    ld1 {v1.h}[1], [x10], #2
+    ld1 {v1.h}[2], [x11], #2
+    ld1 {v1.h}[3], [x12], #2
+
+    fmla v9.8h, v1.8h, v0.h[0]
+    sub w9, w9, #1
+    cbz w9, End1x4
+    b Depth1_1x4
+
+End1x4:
+    faddp v10.8h, v5.8h, v6.8h
+    faddp v11.8h, v7.8h, v8.8h
+    faddp v12.8h, v10.8h, v11.8h
+    faddp v13.8h, v12.8h, v12.8h
+    fadd v13.8h, v13.8h, v9.8h
+
+    cbz x3, Act1x4
+    ld1 {v14.4h}, [x3], #8
+    fadd v13.8h, v13.8h, v14.8h
+
+Act1x4:
+    cmp w4, #3
+    beq Relu6_1x4
+    cmp w4, #1
+    beq Relu1x4
+    b Write1x4
+
+Relu6_1x4:
+    movi v14.8h, #0x46, lsl #8
+    fmin v13.8h, v13.8h, v14.8h
+
+Relu1x4:
+    dup v14.8h, wzr
+    fmax v13.8h, v13.8h, v14.8h
+
+Write1x4:
+    st1 {v13.4h}, [x2], #8
+    sub w6, w6, #4
+    cbz w6, End
+    add x1, x1, x13
+    b Loop
+
+Loop1x1:
+    dup v2.8h, wzr
+    dup v3.8h, wzr
+    dup v4.8h, wzr
+    dup v5.8h, wzr
+    dup v6.8h, wzr
+
+Depth8_1x1:
+    cmp w9, #8
+    blt Depth1_1x1
+
+    ld1 {v0.8h}, [x15], #16
+    ld1 {v1.8h}, [x7], #16
+
+    fmla v2.8h, v1.8h, v0.8h
+    sub w9, w9, #8
+    cbz w9, End1x1
+    b Depth8_1x1
+
+Depth1_1x1:
+    ld1 {v0.h}[0], [x15], #2
+    ld1 {v1.h}[0], [x7], #2
+
+    fmla v3.8h, v1.8h, v0.h[0]
+    sub w9, w9, #1
+    cbz w9, End1x1
+    b Depth1_1x1
+
+End1x1:
+    faddp v4.8h, v2.8h, v2.8h
+    faddp v5.8h, v4.8h, v4.8h
+    faddp v6.8h, v5.8h, v5.8h
+    fadd v6.8h, v6.8h, v3.8h
+
+    cbz x3, Act1x1
+    ld1 {v7.h}[0], [x3], #2
+    fadd v6.8h, v6.8h, v7.8h
+
+Act1x1:
+    cmp w4, #3
+    beq Relu6_1x1
+    cmp w4, #1
+    beq Relu1x1
+    b Write1x1
+
+Relu6_1x1:
+    movi v7.8h, #0x46, lsl #8
+    fmin v6.8h, v6.8h, v7.8h
+
+Relu1x1:
+    dup v7.8h, wzr
+    fmax v6.8h, v6.8h, v7.8h
+
+Write1x1:
+    st1 {v6.h}[0], [x2], #2
+    sub w6, w6, #1
+    cbz w6, End
+    add x1, x1, x8
+    b Loop
+
+End:
+    ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
+    ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64
+    ret
+#endif
\ No newline at end of file
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Matmul12X16Fp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Matmul12X16Fp16.S
new file mode 100644
index 00000000..0af3589e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/Matmul12X16Fp16.S
@@ -0,0 +1,1703 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type : ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu +// x5: depth : Ic +// x6: row : remain_row +// x7: col +// x8: stride : output_stride x8 = x8 * 2 +// x9: writeMode : OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 + +// x17 : input_stride + +asm_function MatMul12x16Fp16Opt + sub sp, sp, #160 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + +.macro CLEAR_OUTPUT_V8_V9 + dup v8.4s, wzr + dup v9.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V11 + dup v8.4s, wzr + dup v9.4s, wzr + dup v10.4s, wzr + dup v11.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V15 + CLEAR_OUTPUT_V8_V11 + dup v12.4s, wzr + dup v13.4s, wzr + dup v14.4s, wzr + dup v15.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V23 + CLEAR_OUTPUT_V8_V15 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V31 + CLEAR_OUTPUT_V8_V23 + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +.endm + + mov x21, #24 + mul x17, x5, x21 // input_stride : 12 * Ic * sizeof(float16_t) + mov x21, #2 + mul x8, x8, x21 // output_stride + +LoopRowStart: + cmp x6, #1 + ble LoopRow1 + cmp x6, #2 + ble LoopRow2 + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow12: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol12: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V31 + cmp x19, #2 + blt LoopDepth12One + + LoopDepth12: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + fmla v16.8h, v4.8h, v0.h[4] + fmla v17.8h, v5.8h, v0.h[4] + fmla v18.8h, v4.8h, v0.h[5] + fmla v19.8h, v5.8h, v0.h[5] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + fmla v20.8h, v4.8h, v0.h[6] + fmla v21.8h, v5.8h, v0.h[6] + fmla v22.8h, v4.8h, v0.h[7] + fmla v23.8h, v5.8h, v0.h[7] + fmla v24.8h, v4.8h, v1.h[0] + fmla v25.8h, v5.8h, v1.h[0] + fmla v26.8h, v4.8h, v1.h[1] + fmla v27.8h, v5.8h, v1.h[1] + fmla v28.8h, v4.8h, v1.h[2] + fmla v29.8h, v5.8h, v1.h[2] + fmla v30.8h, v4.8h, v1.h[3] + fmla 
v31.8h, v5.8h, v1.h[3] + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + fmla v16.8h, v6.8h, v2.h[4] + fmla v17.8h, v7.8h, v2.h[4] + fmla v18.8h, v6.8h, v2.h[5] + fmla v19.8h, v7.8h, v2.h[5] + fmla v20.8h, v6.8h, v2.h[6] + fmla v21.8h, v7.8h, v2.h[6] + fmla v22.8h, v6.8h, v2.h[7] + fmla v23.8h, v7.8h, v2.h[7] + fmla v24.8h, v6.8h, v3.h[0] + fmla v25.8h, v7.8h, v3.h[0] + fmla v26.8h, v6.8h, v3.h[1] + fmla v27.8h, v7.8h, v3.h[1] + fmla v28.8h, v6.8h, v3.h[2] + fmla v29.8h, v7.8h, v3.h[2] + fmla v30.8h, v6.8h, v3.h[3] + fmla v31.8h, v7.8h, v3.h[3] + subs x19, x19, #2 + beq Bias12 + cmp x19, #2 + bge LoopDepth12 + + LoopDepth12One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + fmla v16.8h, v3.8h, v1.h[0] + fmla v17.8h, v4.8h, v1.h[0] + fmla v18.8h, v3.8h, v1.h[1] + fmla v19.8h, v4.8h, v1.h[1] + fmla v20.8h, v3.8h, v1.h[2] + fmla v21.8h, v4.8h, v1.h[2] + fmla v22.8h, v3.8h, v1.h[3] + fmla v23.8h, v4.8h, v1.h[3] + fmla v24.8h, v3.8h, v2.h[0] + fmla v25.8h, v4.8h, v2.h[0] + fmla v26.8h, v3.8h, v2.h[1] + fmla v27.8h, v4.8h, v2.h[1] + fmla v28.8h, v3.8h, v2.h[2] + fmla v29.8h, v4.8h, v2.h[2] + fmla v30.8h, v3.8h, v2.h[3] + fmla v31.8h, v4.8h, v2.h[3] + subs x19, x19, #1 + bgt LoopDepth12One + + Bias12: + cbz x3, Activation12 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v1.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v1.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v1.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v1.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v1.8h + fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v1.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v1.8h + + Activation12: + cmp x4, #3 + beq Relu612 + cmp x4, #1 + beq Relu12 + b Write + + Relu612: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + fmin v24.8h, v24.8h, v2.8h + fmin v25.8h, v25.8h, v2.8h + fmin v26.8h, v26.8h, v2.8h + fmin v27.8h, v27.8h, v2.8h + fmin v28.8h, v28.8h, v2.8h + fmin v29.8h, v29.8h, v2.8h + fmin v30.8h, v30.8h, v2.8h + fmin v31.8h, v31.8h, v2.8h + + Relu12: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + fmax v16.8h, v16.8h, v2.8h + fmax 
v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + fmax v24.8h, v24.8h, v2.8h + fmax v25.8h, v25.8h, v2.8h + fmax v26.8h, v26.8h, v2.8h + fmax v27.8h, v27.8h, v2.8h + fmax v28.8h, v28.8h, v2.8h + fmax v29.8h, v29.8h, v2.8h + fmax v30.8h, v30.8h, v2.8h + fmax v31.8h, v31.8h, v2.8h + b Write + +LoopRow8: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol8: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V23 + cmp x19, #2 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + fmla v16.8h, v4.8h, v0.h[4] + fmla v17.8h, v5.8h, v0.h[4] + fmla v18.8h, v4.8h, v0.h[5] + fmla v19.8h, v5.8h, v0.h[5] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + fmla v20.8h, v4.8h, v0.h[6] + fmla v21.8h, v5.8h, v0.h[6] + fmla v22.8h, v4.8h, v0.h[7] + fmla v23.8h, v5.8h, v0.h[7] + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + fmla v16.8h, v6.8h, v2.h[4] + fmla v17.8h, v7.8h, v2.h[4] + fmla v18.8h, v6.8h, v2.h[5] + fmla v19.8h, v7.8h, v2.h[5] + fmla v20.8h, v6.8h, v2.h[6] + fmla v21.8h, v7.8h, v2.h[6] + fmla v22.8h, v6.8h, v2.h[7] + fmla v23.8h, v7.8h, v2.h[7] + subs x19, x19, #2 + beq Bias8 + cmp x19, #2 + bge LoopDepth8 + + LoopDepth8One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + fmla v16.8h, v3.8h, v1.h[0] + fmla v17.8h, v4.8h, v1.h[0] + fmla v18.8h, v3.8h, v1.h[1] + fmla v19.8h, v4.8h, v1.h[1] + fmla v20.8h, v3.8h, v1.h[2] + fmla v21.8h, v4.8h, v1.h[2] + fmla v22.8h, v3.8h, v1.h[3] + fmla v23.8h, v4.8h, v1.h[3] + subs x19, x19, #1 + bgt LoopDepth8One + + Bias8: + cbz x3, Activation8 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v1.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v1.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v1.8h + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin 
v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + + Relu8: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + b Write + +LoopRow4: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol4: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V15 + cmp x19, #2 + blt LoopDepth4One + + LoopDepth4: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + subs x19, x19, #2 + beq Bias4 + cmp x19, #2 + bge LoopDepth4 + + LoopDepth4One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + subs x19, x19, #1 + bgt LoopDepth4One + + Bias4: + cbz x3, Activation4 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + + Relu4: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + b Write + +LoopRow2: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol2: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V11 + cmp x19, #2 + blt LoopDepth2One + + LoopDepth2: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, 
[x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + subs x19, x19, #2 + beq Bias2 + cmp x19, #2 + bge LoopDepth2 + + LoopDepth2One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + subs x19, x19, #1 + bgt LoopDepth2One + + Bias2: + cbz x3, Activation2 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + + Activation2: + cmp x4, #3 + beq Relu62 + cmp x4, #1 + beq Relu2 + b Write + + Relu62: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + + Relu2: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + b Write + +LoopRow1: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol1: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V9 + cmp x19, #2 + blt LoopDepth1One + + LoopDepth1: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + subs x19, x19, #2 + beq Bias1 + cmp x19, #2 + bge LoopDepth1 + + LoopDepth1One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + subs x19, x19, #1 + bgt LoopDepth1One + + Bias1: + cbz x3, Activation1 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + + Activation1: + cmp x4, #3 + beq Relu61 + cmp x4, #1 + beq Relu1 + b Write + + Relu61: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + + Relu1: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + b Write + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + cmp x13, #8 + beq Write8 + cmp x13, #9 + beq Write9 + cmp x13, #10 + beq Write10 + cmp x13, #11 + beq Write11 + cmp x13, #12 + beq Write12 + cmp x13, #13 + beq Write13 + cmp x13, #14 + beq Write14 + cmp x13, #15 + beq Write15 + b Write16 + + Write1: + add x2, x2, #2 + str h8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str h10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str h12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str h14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str h16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str h18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str h20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str h22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str h24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str h26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str h28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str h30, [x11] + add x11, x11, x8 + add x11, x11, #2 + b 
WriteEnd + + Write2: + add x2, x2, #4 + add x19, x11, #2 + st1 {v8.h}[0], [x11], x8 + st1 {v8.h}[1], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.h}[0], [x11], x8 + st1 {v10.h}[1], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.h}[0], [x11], x8 + st1 {v12.h}[1], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.h}[0], [x11], x8 + st1 {v14.h}[1], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.h}[0], [x11], x8 + st1 {v16.h}[1], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + st1 {v18.h}[1], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + st1 {v20.h}[1], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + st1 {v22.h}[1], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + st1 {v24.h}[1], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + st1 {v26.h}[1], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + st1 {v28.h}[1], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + st1 {v30.h}[1], [x19], x8 + add x11, x11, #4 + b WriteEnd + + Write3: + add x2, x2, #6 + add x19, x11, #4 + add x20, x11, #2 + st1 {v8.h}[0], [x11], x8 + st1 {v8.h}[1], [x20], x8 + st1 {v8.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.h}[0], [x11], x8 + st1 {v10.h}[1], [x20], x8 + st1 {v10.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.h}[0], [x11], x8 + st1 {v12.h}[1], [x20], x8 + st1 {v12.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.h}[0], [x11], x8 + st1 {v14.h}[1], [x20], x8 + st1 {v14.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.h}[0], [x11], x8 + st1 {v16.h}[1], [x20], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + st1 {v18.h}[1], [x20], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + st1 {v20.h}[1], [x20], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + st1 {v22.h}[1], [x20], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + st1 {v24.h}[1], [x20], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + st1 {v26.h}[1], [x20], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + st1 {v28.h}[1], [x20], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + st1 {v30.h}[1], [x20], x8 + st1 {v30.h}[2], [x19], x8 + add x11, x11, #6 + b WriteEnd + + Write4: + add x2, x2, #8 + st1 {v8.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, 
[x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + add x11, x11, #10 + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + add x20, x11, #10 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + st1 {v8.h}[5], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + st1 {v10.h}[5], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + st1 {v12.h}[5], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + st1 {v14.h}[5], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + st1 {v16.h}[5], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + st1 {v18.h}[5], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + st1 {v20.h}[5], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + st1 {v22.h}[5], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + st1 {v24.h}[5], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + st1 {v26.h}[5], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + st1 {v28.h}[5], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + st1 {v30.h}[5], [x20], x8 + add x11, x11, #12 + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x20, x11, #10 + add x10, x11, #12 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + st1 {v8.h}[5], [x20], x8 + st1 {v8.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + st1 {v10.h}[5], [x20], x8 + st1 {v10.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + st1 {v12.h}[5], [x20], x8 + st1 {v12.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + st1 {v14.h}[5], [x20], x8 + st1 {v14.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + st1 {v16.h}[5], [x20], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + st1 {v18.h}[5], [x20], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + st1 {v20.h}[5], [x20], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + st1 {v22.h}[5], [x20], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + st1 {v24.h}[5], [x20], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + st1 {v26.h}[5], [x20], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], 
x8 + st1 {v28.h}[5], [x20], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + st1 {v30.h}[5], [x20], x8 + st1 {v30.h}[6], [x10], x8 + add x11, x11, #14 + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v8.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write9: + add x2, x2, #18 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + add x11, x11, #18 + b WriteEnd + Write10: + add x2, x2, #20 + add x19, x11, #16 + add x20, x11, #18 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + st1 {v9.h}[1], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + st1 {v11.h}[1], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + st1 {v13.h}[1], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + st1 {v15.h}[1], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + st1 {v17.h}[1], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + st1 {v19.h}[1], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + st1 {v21.h}[1], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + st1 {v23.h}[1], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + st1 {v25.h}[1], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + st1 {v27.h}[1], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + st1 {v29.h}[1], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + st1 {v31.h}[1], [x20], x8 + add x11, x11, #20 + b WriteEnd + Write11: + add x2, x2, #22 + add x19, x11, #16 + add x20, x11, #18 + add x10, x11, #20 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + st1 {v9.h}[1], [x20], x8 + st1 {v9.h}[2], [x10], x8 + cmp x6, #1 + 
beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + st1 {v11.h}[1], [x20], x8 + st1 {v11.h}[2], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + st1 {v13.h}[1], [x20], x8 + st1 {v13.h}[2], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + st1 {v15.h}[1], [x20], x8 + st1 {v15.h}[2], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + st1 {v17.h}[1], [x20], x8 + st1 {v17.h}[2], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + st1 {v19.h}[1], [x20], x8 + st1 {v19.h}[2], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + st1 {v21.h}[1], [x20], x8 + st1 {v21.h}[2], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + st1 {v23.h}[1], [x20], x8 + st1 {v23.h}[2], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + st1 {v25.h}[1], [x20], x8 + st1 {v25.h}[2], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + st1 {v27.h}[1], [x20], x8 + st1 {v27.h}[2], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + st1 {v29.h}[1], [x20], x8 + st1 {v29.h}[2], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + st1 {v31.h}[1], [x20], x8 + st1 {v31.h}[2], [x10], x8 + add x11, x11, #22 + b WriteEnd + Write12: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write13: + add x2, x2, #26 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], 
[x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + add x11, x11, #26 + b WriteEnd + Write14: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + add x10, x11, #26 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + st1 {v9.h}[5], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + st1 {v11.h}[5], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + st1 {v13.h}[5], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + st1 {v15.h}[5], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + st1 {v17.h}[5], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + st1 {v19.h}[5], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + st1 {v21.h}[5], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + st1 {v23.h}[5], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], [x20], x8 + st1 {v25.h}[5], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + st1 {v27.h}[5], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + st1 {v29.h}[5], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + st1 {v31.h}[5], [x10], x8 + add x11, x11, #28 + b WriteEnd + Write15: + add x2, x2, #30 + add x19, x11, #16 + add x20, x11, #24 + add x10, x11, #26 + add x15, x11, #28 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + st1 {v9.h}[5], [x10], x8 + st1 {v9.h}[6], [x15], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + st1 {v11.h}[5], [x10], x8 + st1 {v11.h}[6], [x15], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + st1 {v13.h}[5], [x10], x8 + st1 {v13.h}[6], [x15], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + st1 {v15.h}[5], [x10], x8 + st1 {v15.h}[6], [x15], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + st1 {v17.h}[5], [x10], x8 + st1 {v17.h}[6], [x15], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + st1 {v19.h}[5], [x10], x8 + st1 {v19.h}[6], [x15], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + st1 {v21.h}[5], [x10], x8 + st1 {v21.h}[6], [x15], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + st1 {v23.h}[5], [x10], x8 + st1 {v23.h}[6], [x15], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, 
[x19], x8 + st1 {v25.h}[4], [x20], x8 + st1 {v25.h}[5], [x10], x8 + st1 {v25.h}[6], [x15], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + st1 {v27.h}[5], [x10], x8 + st1 {v27.h}[6], [x15], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + st1 {v29.h}[5], [x10], x8 + st1 {v29.h}[6], [x15], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + st1 {v31.h}[5], [x10], x8 + st1 {v31.h}[6], [x15], x8 + add x11, x11, #30 + b WriteEnd + Write16: + add x2, x2, #32 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.8h}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.8h}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.8h}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.8h}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.8h}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.8h}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.8h}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.8h}, [x19], x8 + add x11, x11, #32 + b WriteEnd + + WriteEnd: + subs x13, x13, #16 // col - 16 + ble LoopColEnd + cmp x6, #1 + ble LoopCol1 + cmp x6, #2 + ble LoopCol2 + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol12 + +LoopColEnd: + add x0, x0, x17 + mov x21, #2 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulBaseFp16Neon.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulBaseFp16Neon.S new file mode 100644 index 00000000..228e73ec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulBaseFp16Neon.S @@ -0,0 +1,960 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+
+.text
+.align 5
+
+// void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
+//                         int depth, int row, int col, size_t stride, size_t writeMode)
+// x0: a
+// x1: b
+// x2: c
+// x3: bias
+// x4: act_type
+// x5: depth
+// x6: row
+// x7: col
+// x8: stride
+// x9: writeMode
+
+asm_function MatmulBaseFp16Neon
+    sub sp, sp, #160
+    st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp]
+    add x9, sp, #64
+    st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9]
+    stp x19, x20, [sp, #128]
+    stp x21, x22, [sp, #144]
+
+    ldr x8, [sp, #160]
+    ldr x9, [sp, #168]  // writeMode
+    add x8, x8, x8      // stride * sizeof(float16_t)
+
+    add x16, x7, x7     // col * sizeof(float16_t)
+    add x17, x5, x5     // depth * sizeof(float16_t)
+    mov x11, x2
+    dup v12.8h, wzr
+    movi v13.8h, #0x46, lsl #8
+LoopRowStart:
+    cmp x6, #16
+    bge LoopRow16
+    cmp x6, #8
+    bge LoopRow8
+    b LoopRow4
+
+LoopRow16:
+    mov x15, #16
+    mov x14, x1  // reload rhs ptr
+    mov x13, x7  // reload rhs col
+    mov x12, x3  // reload bias
+
+    LoopCol16:
+        mov x11, x2
+        mov x10, x0  // reload lhs ptr
+        mov x19, x5  // reload depth
+
+        ld1 {v16.8h}, [x12], #16
+        mov v17.16b, v16.16b
+        mov v18.16b, v16.16b
+        mov v19.16b, v16.16b
+        mov v20.16b, v16.16b
+        mov v21.16b, v16.16b
+        mov v22.16b, v16.16b
+        mov v23.16b, v16.16b
+        mov v24.16b, v16.16b
+        mov v25.16b, v16.16b
+        mov v26.16b, v16.16b
+        mov v27.16b, v16.16b
+        mov v28.16b, v16.16b
+        mov v29.16b, v16.16b
+        mov v30.16b, v16.16b
+        mov v31.16b, v16.16b
+
+        cmp x19, #4
+        blt LoopDepth16One
+
+        LoopDepth16:
+            ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
+            ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64
+            fmla v16.8h, v8.8h, v0.h[0]
+            fmla v17.8h, v8.8h, v0.h[1]
+            fmla v18.8h, v8.8h, v0.h[2]
+            fmla v19.8h, v8.8h, v0.h[3]
+            fmla v20.8h, v8.8h, v0.h[4]
+            fmla v21.8h, v8.8h, v0.h[5]
+            fmla v22.8h, v8.8h, v0.h[6]
+            fmla v23.8h, v8.8h, v0.h[7]
+            fmla v24.8h, v8.8h, v1.h[0]
+            fmla v25.8h, v8.8h, v1.h[1]
+            fmla v26.8h, v8.8h, v1.h[2]
+            fmla v27.8h, v8.8h, v1.h[3]
+            fmla v28.8h, v8.8h, v1.h[4]
+            fmla v29.8h, v8.8h, v1.h[5]
+            fmla v30.8h, v8.8h, v1.h[6]
+            fmla v31.8h, v8.8h, v1.h[7]
+            ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64
+            fmla v16.8h, v9.8h, v2.h[0]
+            fmla v17.8h, v9.8h, v2.h[1]
+            fmla v18.8h, v9.8h, v2.h[2]
+            fmla v19.8h, v9.8h, v2.h[3]
+            fmla v20.8h, v9.8h, v2.h[4]
+            fmla v21.8h, v9.8h, v2.h[5]
+            fmla v22.8h, v9.8h, v2.h[6]
+            fmla v23.8h, v9.8h, v2.h[7]
+            fmla v24.8h, v9.8h, v3.h[0]
+            fmla v25.8h, v9.8h, v3.h[1]
+            fmla v26.8h, v9.8h, v3.h[2]
+            fmla v27.8h, v9.8h, v3.h[3]
+            fmla v28.8h, v9.8h, v3.h[4]
+            fmla v29.8h, v9.8h, v3.h[5]
+            fmla v30.8h, v9.8h, v3.h[6]
+            fmla v31.8h, v9.8h, v3.h[7]
+            fmla v16.8h, v10.8h, v4.h[0]
+            fmla v17.8h, v10.8h, v4.h[1]
+            fmla v18.8h, v10.8h, v4.h[2]
+            fmla v19.8h, v10.8h, v4.h[3]
+            fmla v20.8h, v10.8h, v4.h[4]
+            fmla v21.8h, v10.8h, v4.h[5]
+            fmla v22.8h, v10.8h, v4.h[6]
+            fmla v23.8h, v10.8h, v4.h[7]
+            fmla v24.8h, v10.8h, v5.h[0]
+            fmla v25.8h, v10.8h, v5.h[1]
+            fmla v26.8h, v10.8h, v5.h[2]
+            fmla v27.8h, v10.8h, v5.h[3]
+            fmla v28.8h, v10.8h, v5.h[4]
+            fmla v29.8h, v10.8h, v5.h[5]
+            fmla v30.8h, v10.8h, v5.h[6]
+            fmla v31.8h, v10.8h, v5.h[7]
+            fmla v16.8h, v11.8h, v6.h[0]
+            fmla v17.8h, v11.8h, v6.h[1]
+            fmla v18.8h, v11.8h, v6.h[2]
+            fmla v19.8h, v11.8h, v6.h[3]
+            fmla v20.8h, v11.8h, v6.h[4]
+            fmla v21.8h, v11.8h, v6.h[5]
+            fmla v22.8h, v11.8h, v6.h[6]
+            fmla v23.8h, v11.8h, v6.h[7]
+            fmla v24.8h, v11.8h, v7.h[0]
+            fmla v25.8h, v11.8h, v7.h[1]
+            fmla v26.8h, v11.8h, v7.h[2]
+            fmla v27.8h,
v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + + subs x19, x19, #4 + beq Activation16 + cmp x19, #4 + bge LoopDepth16 + + LoopDepth16One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + fmla v24.8h, v2.8h, v1.h[0] + fmla v25.8h, v2.8h, v1.h[1] + fmla v26.8h, v2.8h, v1.h[2] + fmla v27.8h, v2.8h, v1.h[3] + fmla v28.8h, v2.8h, v1.h[4] + fmla v29.8h, v2.8h, v1.h[5] + fmla v30.8h, v2.8h, v1.h[6] + fmla v31.8h, v2.8h, v1.h[7] + subs x19, x19, #1 + bgt LoopDepth16One + + Activation16: + cmp x4, #3 + beq Relu616 + cmp x4, #1 + beq Relu16 + b Write16 + Relu616: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + fmin v20.8h, v20.8h, v13.8h + fmin v21.8h, v21.8h, v13.8h + fmin v22.8h, v22.8h, v13.8h + fmin v23.8h, v23.8h, v13.8h + fmin v24.8h, v24.8h, v13.8h + fmin v25.8h, v25.8h, v13.8h + fmin v26.8h, v26.8h, v13.8h + fmin v27.8h, v27.8h, v13.8h + fmin v28.8h, v28.8h, v13.8h + fmin v29.8h, v29.8h, v13.8h + fmin v30.8h, v30.8h, v13.8h + fmin v31.8h, v31.8h, v13.8h + Relu16: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + fmax v20.8h, v20.8h, v12.8h + fmax v21.8h, v21.8h, v12.8h + fmax v22.8h, v22.8h, v12.8h + fmax v23.8h, v23.8h, v12.8h + fmax v24.8h, v24.8h, v12.8h + fmax v25.8h, v25.8h, v12.8h + fmax v26.8h, v26.8h, v12.8h + fmax v27.8h, v27.8h, v12.8h + fmax v28.8h, v28.8h, v12.8h + fmax v29.8h, v29.8h, v12.8h + fmax v30.8h, v30.8h, v12.8h + fmax v31.8h, v31.8h, v12.8h + Write16: + cmp x13, #8 + bge Write16x8 + b Write + Write16x8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x11], x8 + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x11], x8 + st1 {v24.8h}, [x11], x8 + st1 {v25.8h}, [x11], x8 + st1 {v26.8h}, [x11], x8 + st1 {v27.8h}, [x11], x8 + st1 {v28.8h}, [x11], x8 + st1 {v29.8h}, [x11], x8 + st1 {v30.8h}, [x11], x8 + st1 {v31.8h}, [x11], x8 + b WriteEnd + +LoopRow8: + mov x15, #8 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v21.16b, v16.16b + mov v22.16b, v16.16b + mov v23.16b, v16.16b + + cmp x19, #4 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v16.8h, v9.8h, v1.h[0] + fmla v17.8h, v9.8h, v1.h[1] + fmla v18.8h, v9.8h, v1.h[2] + fmla v19.8h, v9.8h, v1.h[3] + fmla v20.8h, v9.8h, v1.h[4] + fmla v21.8h, v9.8h, v1.h[5] + fmla v22.8h, v9.8h, v1.h[6] + fmla v23.8h, v9.8h, v1.h[7] + fmla v16.8h, v10.8h, v2.h[0] + fmla v17.8h, v10.8h, v2.h[1] + fmla v18.8h, v10.8h, v2.h[2] + fmla v19.8h, v10.8h, v2.h[3] + fmla v20.8h, 
v10.8h, v2.h[4] + fmla v21.8h, v10.8h, v2.h[5] + fmla v22.8h, v10.8h, v2.h[6] + fmla v23.8h, v10.8h, v2.h[7] + fmla v16.8h, v11.8h, v3.h[0] + fmla v17.8h, v11.8h, v3.h[1] + fmla v18.8h, v11.8h, v3.h[2] + fmla v19.8h, v11.8h, v3.h[3] + fmla v20.8h, v11.8h, v3.h[4] + fmla v21.8h, v11.8h, v3.h[5] + fmla v22.8h, v11.8h, v3.h[6] + fmla v23.8h, v11.8h, v3.h[7] + subs x19, x19, #4 + beq Activation8 + cmp x19, #4 + bge LoopDepth8 + LoopDepth8One: + ld1 {v0.8h}, [x10], #16 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + subs x19, x19, #1 + bgt LoopDepth8One + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write8_Row + Relu68: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + fmin v20.8h, v20.8h, v13.8h + fmin v21.8h, v21.8h, v13.8h + fmin v22.8h, v22.8h, v13.8h + fmin v23.8h, v23.8h, v13.8h + Relu8: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + fmax v20.8h, v20.8h, v12.8h + fmax v21.8h, v21.8h, v12.8h + fmax v22.8h, v22.8h, v12.8h + fmax v23.8h, v23.8h, v12.8h + Write8_Row: + cmp x13, #8 // row + bge Write8x8 + b Write + Write8x8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x11], x8 + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x11], x8 + b WriteEnd + +LoopRow4: + mov x15, #4 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + cmp x19, #4 + blt LoopDepth4One + LoopDepth4: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v16.8h, v9.8h, v0.h[4] + fmla v17.8h, v9.8h, v0.h[5] + fmla v18.8h, v9.8h, v0.h[6] + fmla v19.8h, v9.8h, v0.h[7] + fmla v16.8h, v10.8h, v1.h[0] + fmla v17.8h, v10.8h, v1.h[1] + fmla v18.8h, v10.8h, v1.h[2] + fmla v19.8h, v10.8h, v1.h[3] + fmla v16.8h, v11.8h, v1.h[4] + fmla v17.8h, v11.8h, v1.h[5] + fmla v18.8h, v11.8h, v1.h[6] + fmla v19.8h, v11.8h, v1.h[7] + subs x19, x19, #4 + beq Activation4 + cmp x19, #4 + bge LoopDepth4 + LoopDepth4One: + ld1 {v0.4h}, [x10], #8 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + subs x19, x19, #1 + bgt LoopDepth4One + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write4_Row + Relu64: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + Relu4: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + Write4_Row: + cmp x6, #4 + bge Write4x8 + b Write + Write4x8: + cmp x13, #8 + blt Write + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + b WriteEnd + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, 
#4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #2 + st1 {v16.h}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.h}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.h}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.h}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.h}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.h}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.h}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.h}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.h}[0], [x11], x8 + b WriteEnd + Write2: + add x2, x2, #4 + st1 {v16.s}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + b WriteEnd + Write3: + add x2, x2, #6 + add x19, x11, #4 + st1 {v16.s}[0], [x11], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + st1 {v17.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + st1 {v19.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + st1 {v21.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + st1 {v23.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + st1 {v25.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + st1 {v27.h}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + st1 {v29.h}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + st1 {v30.h}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + st1 {v31.h}[2], [x19] + b WriteEnd + Write4: + add x2, x2, #8 + st1 {v16.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + cmp x6, #4 + beq 
WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.h}[4], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.h}[4], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.h}[4], [x19] + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x10, x11, #12 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 
{v17.s}[2], [x19], x8 + st1 {v17.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + st1 {v19.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + st1 {v21.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + st1 {v23.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + st1 {v25.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + st1 {v27.h}[6], [x10], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + st1 {v29.h}[6], [x10], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + st1 {v30.h}[6], [x10], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + st1 {v31.h}[6], [x10] + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.8h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.8h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.8h}, [x11], x8 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #16 + bge LoopCol16 + cmp x6, #8 + bge LoopCol8 + b LoopCol4 + +LoopColEnd: + sub x2, x2, x16 // dst - col * 2 + mul x21, x8, x15 // row_block * col * 2 + add x2, x2, x21 + subs x6, x6, x15 + mul x15, x15, x17 + add x0, x0, x15 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16.S new file mode 100644 index 00000000..0f01e6f1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16.S @@ -0,0 +1,892 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, int stride, bool write_nhwc) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: row +// w7: col +// w17: stride +// w13: writeC8 + +asm_function MatmulFp16Neon64 + sub sp, sp, #144 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + + mov w18, #16 // sizeof(float16) * 8 + mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float16) * 8 * depth + mov x11, x3 // bias flag + mov x19, #2 + ldr x17, [sp, #144] + mul x17, x17, x19 + +L1: + mov w10, w6 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x19, x2 // reload dst ptr + +L2: + mov x16, x1 // reload rhs ptr + mov w13, w5 // reload depth + mov x14, x3 // reload bias ptr + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp w13, #8 + blt CommLoopMul + +OptLoopMul8: + ld1 {v0.8h, v1.8h}, [x12], #32 + ld1 {v8.8h, v9.8h}, [x16], #32 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + ld1 {v2.8h, v3.8h}, [x12], #32 + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v10.8h, v11.8h}, [x16], #32 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x16], #64 + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + ld1 {v4.8h, v5.8h}, [x12], #32 + fmla v16.8h, v11.8h, v6.h[0] + fmla 
v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h, v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + ld1 {v6.8h, v7.8h}, [x12], #32 + fmla v16.8h, v12.8h, v0.h[0] + fmla v17.8h, v12.8h, v0.h[1] + fmla v18.8h, v12.8h, v0.h[2] + fmla v19.8h, v12.8h, v0.h[3] + fmla v20.8h, v12.8h, v0.h[4] + fmla v21.8h, v12.8h, v0.h[5] + fmla v22.8h, v12.8h, v0.h[6] + fmla v23.8h, v12.8h, v0.h[7] + fmla v24.8h, v12.8h, v1.h[0] + fmla v25.8h, v12.8h, v1.h[1] + fmla v26.8h, v12.8h, v1.h[2] + fmla v27.8h, v12.8h, v1.h[3] + fmla v28.8h, v12.8h, v1.h[4] + fmla v29.8h, v12.8h, v1.h[5] + fmla v30.8h, v12.8h, v1.h[6] + fmla v31.8h, v12.8h, v1.h[7] + fmla v16.8h, v13.8h, v2.h[0] + fmla v17.8h, v13.8h, v2.h[1] + fmla v18.8h, v13.8h, v2.h[2] + fmla v19.8h, v13.8h, v2.h[3] + fmla v20.8h, v13.8h, v2.h[4] + fmla v21.8h, v13.8h, v2.h[5] + fmla v22.8h, v13.8h, v2.h[6] + fmla v23.8h, v13.8h, v2.h[7] + fmla v24.8h, v13.8h, v3.h[0] + fmla v25.8h, v13.8h, v3.h[1] + fmla v26.8h, v13.8h, v3.h[2] + fmla v27.8h, v13.8h, v3.h[3] + fmla v28.8h, v13.8h, v3.h[4] + fmla v29.8h, v13.8h, v3.h[5] + fmla v30.8h, v13.8h, v3.h[6] + fmla v31.8h, v13.8h, v3.h[7] + fmla v16.8h, v14.8h, v4.h[0] + fmla v17.8h, v14.8h, v4.h[1] + fmla v18.8h, v14.8h, v4.h[2] + fmla v19.8h, v14.8h, v4.h[3] + fmla v20.8h, v14.8h, v4.h[4] + fmla v21.8h, v14.8h, v4.h[5] + fmla v22.8h, v14.8h, v4.h[6] + fmla v23.8h, v14.8h, v4.h[7] + fmla v24.8h, v14.8h, v5.h[0] + fmla v25.8h, v14.8h, v5.h[1] + fmla v26.8h, v14.8h, v5.h[2] + fmla v27.8h, v14.8h, v5.h[3] + fmla v28.8h, v14.8h, v5.h[4] + fmla v29.8h, v14.8h, v5.h[5] + fmla v30.8h, v14.8h, v5.h[6] + fmla v31.8h, v14.8h, v5.h[7] + fmla v16.8h, v15.8h, v6.h[0] + fmla v17.8h, v15.8h, v6.h[1] + fmla v18.8h, v15.8h, v6.h[2] + fmla v19.8h, v15.8h, v6.h[3] + fmla v20.8h, v15.8h, v6.h[4] + fmla v21.8h, v15.8h, v6.h[5] + fmla v22.8h, v15.8h, v6.h[6] + fmla v23.8h, v15.8h, v6.h[7] + fmla v24.8h, v15.8h, v7.h[0] + fmla v25.8h, v15.8h, v7.h[1] + fmla v26.8h, v15.8h, v7.h[2] + fmla v27.8h, v15.8h, v7.h[3] + fmla v28.8h, v15.8h, v7.h[4] + fmla v29.8h, v15.8h, v7.h[5] + fmla v30.8h, v15.8h, v7.h[6] + fmla v31.8h, v15.8h, v7.h[7] + + sub w13, w13, #8 + cmp w13, #0 + ble Bias + cmp w13, #8 + bge OptLoopMul8 + +CommLoopMul: + ld1 {v0.8h, v1.8h}, [x12], #32 + ld1 {v8.8h}, [x16], #16 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + + subs w13, w13, #1 + bgt CommLoopMul + +Bias: + cbz x11, Activation + ld1 {v0.8h}, [x14], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v0.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v0.8h + 
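For readers tracing the fmla ladder above: OptLoopMul8 unrolls eight depth steps and CommLoopMul mops up the remainder one step at a time; in both, every fmla broadcasts a single lhs element (v0.h[r]) across an 8-wide rhs vector, so v16..v31 hold a complete 16x8 fp16 tile. A minimal scalar C sketch of the same computation, assuming lhs is packed 16 rows per depth step and rhs 8 columns per depth step (MatmulTile16x8Ref is a hypothetical name, not part of this patch):

#include <stddef.h>

typedef _Float16 float16_t; /* assumes GCC/Clang _Float16 support on AArch64 */

static void MatmulTile16x8Ref(const float16_t *a, const float16_t *b,
                              float16_t *c, const float16_t *bias, int depth) {
  float16_t acc[16][8] = {{0}};       /* mirrors v16..v31, zeroed like dup vN.4s, wzr */
  for (int d = 0; d < depth; ++d) {   /* one CommLoopMul step */
    for (int r = 0; r < 16; ++r) {
      for (int col = 0; col < 8; ++col) {
        /* fmla vN.8h, v8.8h, v0.h[r]: one lhs element times 8 rhs lanes */
        acc[r][col] += a[d * 16 + r] * b[d * 8 + col];
      }
    }
  }
  for (int r = 0; r < 16; ++r) {      /* Bias: fadd vN.8h, vN.8h, v0.8h */
    for (int col = 0; col < 8; ++col) {
      c[r * 8 + col] = acc[r][col] + (bias != NULL ? bias[col] : (float16_t)0);
    }
  }
}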
fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v0.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v0.8h + +Activation: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + movi v15.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v15.8h + fmin v17.8h, v17.8h, v15.8h + fmin v18.8h, v18.8h, v15.8h + fmin v19.8h, v19.8h, v15.8h + fmin v20.8h, v20.8h, v15.8h + fmin v21.8h, v21.8h, v15.8h + fmin v22.8h, v22.8h, v15.8h + fmin v23.8h, v23.8h, v15.8h + fmin v24.8h, v24.8h, v15.8h + fmin v25.8h, v25.8h, v15.8h + fmin v26.8h, v26.8h, v15.8h + fmin v27.8h, v27.8h, v15.8h + fmin v28.8h, v28.8h, v15.8h + fmin v29.8h, v29.8h, v15.8h + fmin v30.8h, v30.8h, v15.8h + fmin v31.8h, v31.8h, v15.8h + +Relu: + dup v14.4s, wzr + fmax v16.8h, v16.8h, v14.8h + fmax v17.8h, v17.8h, v14.8h + fmax v18.8h, v18.8h, v14.8h + fmax v19.8h, v19.8h, v14.8h + fmax v20.8h, v20.8h, v14.8h + fmax v21.8h, v21.8h, v14.8h + fmax v22.8h, v22.8h, v14.8h + fmax v23.8h, v23.8h, v14.8h + fmax v24.8h, v24.8h, v14.8h + fmax v25.8h, v25.8h, v14.8h + fmax v26.8h, v26.8h, v14.8h + fmax v27.8h, v27.8h, v14.8h + fmax v28.8h, v28.8h, v14.8h + fmax v29.8h, v29.8h, v14.8h + fmax v30.8h, v30.8h, v14.8h + fmax v31.8h, v31.8h, v14.8h + +Write: + ldrb w13, [sp, #152] + cbz w13, WriteC8 + cmp w7, #1 + beq Write1 + cmp w7, #2 + beq Write2 + cmp w7, #3 + beq Write3 + cmp w7, #4 + beq Write4 + cmp w7, #5 + beq Write5 + cmp w7, #6 + beq Write6 + cmp w7, #7 + beq Write7 + b Write8 + +Write1: + st1 {v16.h}[0], [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.h}[0], [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + b WriteEnd +Write2: + add x13, x19, #2 + st1 {v16.h}[0], [x19], x17 + st1 {v16.h}[1], [x13], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + st1 {v17.h}[1], [x13], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + st1 {v18.h}[1], [x13], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + st1 {v19.h}[1], [x13], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + st1 {v20.h}[1], [x13], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + st1 {v21.h}[1], [x13], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + st1 {v22.h}[1], [x13], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + st1 {v23.h}[1], [x13], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + st1 {v24.h}[1], [x13], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + st1 {v25.h}[1], [x13], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + st1 {v26.h}[1], [x13], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + st1 {v27.h}[1], [x13], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + st1 {v28.h}[1], [x13], x17 + cmp w10, #13 + beq WriteEnd + 
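A worked detail for the Relu6 path above: movi v15.8h, #0x46, lsl #8 fills every lane with 0x4600, which is exactly 6.0 in IEEE binary16 (sign 0, biased exponent 17, mantissa 0x200, i.e. 1.5 * 2^2 = 6.0); the Relu lower bound needs no constant at all, reusing an all-zero register via dup v14.4s, wzr. A stand-alone check of the constant (illustrative, not project code):

#include <assert.h>
#include <stdint.h>

static float HalfBitsToFloat(uint16_t h) {
  int sign = (h >> 15) & 0x1;
  int exp = (h >> 10) & 0x1F;           /* 5-bit exponent, bias 15 */
  int mant = h & 0x3FF;                 /* 10-bit mantissa */
  float value = 1.0f + mant / 1024.0f;  /* normal numbers only: enough here */
  for (int e = exp - 15; e > 0; --e) value *= 2.0f;
  for (int e = exp - 15; e < 0; ++e) value *= 0.5f;
  return sign ? -value : value;
}

int main(void) {
  assert(HalfBitsToFloat(0x4600) == 6.0f);  /* the Relu6 clamp */
  return 0;
}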
st1 {v29.h}[0], [x19], x17 + st1 {v29.h}[1], [x13], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + st1 {v30.h}[1], [x13], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + st1 {v31.h}[1], [x13], x17 + b WriteEnd +Write3: + add x13, x19, #2 + add x14, x19, #4 + st1 {v16.h}[0], [x19], x17 + st1 {v16.h}[1], [x13], x17 + st1 {v16.h}[2], [x14], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + st1 {v17.h}[1], [x13], x17 + st1 {v17.h}[2], [x14], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + st1 {v18.h}[1], [x13], x17 + st1 {v18.h}[2], [x14], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + st1 {v19.h}[1], [x13], x17 + st1 {v19.h}[2], [x14], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + st1 {v20.h}[1], [x13], x17 + st1 {v20.h}[2], [x14], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + st1 {v21.h}[1], [x13], x17 + st1 {v21.h}[2], [x14], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + st1 {v22.h}[1], [x13], x17 + st1 {v22.h}[2], [x14], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + st1 {v23.h}[1], [x13], x17 + st1 {v23.h}[2], [x14], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + st1 {v24.h}[1], [x13], x17 + st1 {v24.h}[2], [x14], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + st1 {v25.h}[1], [x13], x17 + st1 {v25.h}[2], [x14], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + st1 {v26.h}[1], [x13], x17 + st1 {v26.h}[2], [x14], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + st1 {v27.h}[1], [x13], x17 + st1 {v27.h}[2], [x14], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + st1 {v28.h}[1], [x13], x17 + st1 {v28.h}[2], [x14], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.h}[0], [x19], x17 + st1 {v29.h}[1], [x13], x17 + st1 {v29.h}[2], [x14], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + st1 {v30.h}[1], [x13], x17 + st1 {v30.h}[2], [x14], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + st1 {v31.h}[1], [x13], x17 + st1 {v31.h}[2], [x14], x17 + b WriteEnd +Write4: + st1 {v16.4h}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + b WriteEnd +Write5: + add x13, x19, #8 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], x17 + cmp w10, 
#6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + b WriteEnd +Write6: + add x13, x19, #8 + add x14, x19, #10 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + st1 {v16.h}[5], [x14], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + st1 {v17.h}[5], [x14], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + st1 {v18.h}[5], [x14], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + st1 {v19.h}[5], [x14], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + st1 {v20.h}[5], [x14], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], x17 + st1 {v21.h}[5], [x14], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + st1 {v22.h}[5], [x14], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + st1 {v23.h}[5], [x14], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + st1 {v24.h}[5], [x14], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + st1 {v25.h}[5], [x14], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + st1 {v26.h}[5], [x14], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + st1 {v27.h}[5], [x14], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + st1 {v28.h}[5], [x14], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + st1 {v29.h}[5], [x14], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + st1 {v30.h}[5], [x14], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + st1 {v31.h}[5], [x14], x17 + b WriteEnd +Write7: + add x13, x19, #8 + add x14, x19, #10 + add x16, x19, #12 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + st1 {v16.h}[5], [x14], x17 + st1 {v16.h}[6], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + st1 {v17.h}[5], [x14], x17 + st1 {v17.h}[6], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + st1 {v18.h}[5], [x14], x17 + st1 {v18.h}[6], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + st1 {v19.h}[5], [x14], x17 + st1 {v19.h}[6], [x16], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + st1 {v20.h}[5], [x14], x17 + st1 {v20.h}[6], [x16], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], 
x17 + st1 {v21.h}[5], [x14], x17 + st1 {v21.h}[6], [x16], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + st1 {v22.h}[5], [x14], x17 + st1 {v22.h}[6], [x16], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + st1 {v23.h}[5], [x14], x17 + st1 {v23.h}[6], [x16], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + st1 {v24.h}[5], [x14], x17 + st1 {v24.h}[6], [x16], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + st1 {v25.h}[5], [x14], x17 + st1 {v25.h}[6], [x16], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + st1 {v26.h}[5], [x14], x17 + st1 {v26.h}[6], [x16], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + st1 {v27.h}[5], [x14], x17 + st1 {v27.h}[6], [x16], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + st1 {v28.h}[5], [x14], x17 + st1 {v28.h}[6], [x16], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + st1 {v29.h}[5], [x14], x17 + st1 {v29.h}[6], [x16], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + st1 {v30.h}[5], [x14], x17 + st1 {v30.h}[6], [x16], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + st1 {v31.h}[5], [x14], x17 + st1 {v31.h}[6], [x16], x17 + b WriteEnd +WriteC8: + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64 + b WriteEnd +Write8: + st1 {v16.8h}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.8h}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.8h}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.8h}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.8h}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.8h}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.8h}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.8h}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.8h}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.8h}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.8h}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.8h}, [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.8h}, [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.8h}, [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.8h}, [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.8h}, [x19], x17 + +WriteEnd: + subs w10, w10, #16 // lhs row - 8 + bgt L2 + +End2: + subs w7, w7, #8 // rhs col - 8 + add x1, x1, x15 // rhs ptr + stride + add x3, x3, #16 // bias ptr + stride + ldrb w13, [sp, #152] + cbz w13, NoDstStep + add x2, x2, #16 // dst ptr + stride +NoDstStep: + bgt L1 + +End1: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S new file mode 100644 index 00000000..c55e83a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16Opt.S @@ -0,0 +1,1185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFp16Neon64Opt + sub sp, sp, #96 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + + ldr x8, [sp, #96] + ldr x9, [sp, #104] + + mov x21, #32 // sizeof(float16_t) * 16 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float16_t) * 16 * depth + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #16 + mul x16, x6, x21 // row * 8 * sizeof(float16_t) +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #2 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float16_t) + mov x21, #16 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float16_t) +NoWinoSteps: + mov x21, #2 + mul x8, x8, x21 + +LoopRowStart: + cmp x6, #1 + ble LoopRow + cmp x6, #2 + ble LoopRow2 + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow16: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol16: + cbz x9, NoReloadDst16 + mov x11, x2 + NoReloadDst16: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp x19, #4 + blt LoopDepth16One + + LoopDepth16: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, 
v10.8h, v4.h[7] + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h, v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + + subs x19, x19, #4 + beq Bias16 + cmp x19, #4 + bge LoopDepth16 + + LoopDepth16One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + fmla v24.8h, v2.8h, v1.h[0] + fmla v25.8h, v2.8h, v1.h[1] + fmla v26.8h, v2.8h, v1.h[2] + fmla v27.8h, v2.8h, v1.h[3] + fmla v28.8h, v2.8h, v1.h[4] + fmla v29.8h, v2.8h, v1.h[5] + fmla v30.8h, v2.8h, v1.h[6] + fmla v31.8h, v2.8h, v1.h[7] + + subs x19, x19, #1 + bgt LoopDepth16One + + Bias16: + cbz x3, Activation16 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v0.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v0.8h + fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v0.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v0.8h + + Activation16: + cmp x4, #3 + beq Relu616 + cmp x4, #1 + beq Relu16 + b Write + + Relu616: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + fmin v24.8h, v24.8h, v2.8h + fmin v25.8h, v25.8h, v2.8h + fmin v26.8h, v26.8h, v2.8h + fmin v27.8h, v27.8h, v2.8h + fmin v28.8h, v28.8h, v2.8h + fmin v29.8h, v29.8h, v2.8h + fmin v30.8h, v30.8h, v2.8h + fmin v31.8h, v31.8h, v2.8h + + Relu16: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + fmax v24.8h, v24.8h, v2.8h + fmax v25.8h, v25.8h, v2.8h + fmax v26.8h, v26.8h, v2.8h + fmax v27.8h, v27.8h, v2.8h + fmax v28.8h, v28.8h, v2.8h + fmax v29.8h, v29.8h, v2.8h + fmax v30.8h, v30.8h, v2.8h + fmax v31.8h, v31.8h, v2.8h + b Write + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + mov x11, x2 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + cmp x19, #4 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, 
v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + + subs x19, x19, #4 + beq Bias8 + cmp x19, #4 + bge LoopDepth8 + + LoopDepth8One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + + subs x19, x19, #1 + bgt LoopDepth8One + + Bias8: + cbz x3, Activation8 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + + Relu8: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + mov x11, x2 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + + cmp x19, #4 + blt LoopDepth4One + + LoopDepth4: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + + subs x19, x19, #4 + beq Bias4 + cmp x19, #4 + bge LoopDepth4 + + LoopDepth4One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla 
v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + + subs x19, x19, #1 + bgt LoopDepth4One + + Bias4: + cbz x3, Activation4 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + + Relu4: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + b Write + +LoopRow2: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol2: + cbz x9, NoReloadDst2 + mov x11, x2 + NoReloadDst2: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + + cmp x19, #4 + blt LoopDepth2One + + LoopDepth2: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + + subs x19, x19, #4 + beq Bias2 + cmp x19, #4 + bge LoopDepth2 + + LoopDepth2One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + + subs x19, x19, #1 + bgt LoopDepth2One + + Bias2: + cbz x3, Activation2 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + + Activation2: + cmp x4, #3 + beq Relu62 + cmp x4, #1 + beq Relu2 + b Write + + Relu62: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + + Relu2: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + b Write + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + mov x11, x2 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + + cmp x19, #4 + blt LoopDepthOne + + LoopDepth: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v16.8h, v10.8h, v4.h[0] + fmla v16.8h, v11.8h, v6.h[0] + + subs x19, x19, #4 + beq Bias + cmp x19, #4 + bge LoopDepth + + LoopDepthOne: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + + subs x19, x19, #1 + bgt LoopDepthOne + + Bias: + cbz x3, Activation + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + + Relu: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + + Write: + cmp x9, #2 + beq WriteWino + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #2 + str h16, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str h17, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str h18, [x11] + cmp x6, #3 + beq 
WriteEnd + add x11, x11, x8 + str h19, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str h20, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str h21, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str h22, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str h23, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str h24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str h25, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str h26, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str h27, [x11] + cmp x6, #12 + beq WriteEnd + add x11, x11, x8 + str h28, [x11] + cmp x6, #13 + beq WriteEnd + add x11, x11, x8 + str h29, [x11] + cmp x6, #14 + beq WriteEnd + add x11, x11, x8 + str h30, [x11] + cmp x6, #15 + beq WriteEnd + add x11, x11, x8 + str h31, [x11] + add x11, x11, x8 + add x11, x11, #2 + b WriteEnd + Write2: + add x2, x2, #4 + st1 {v16.s}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + add x11, x11, #4 + b WriteEnd + Write3: + add x2, x2, #6 + add x19, x11, #4 + st1 {v16.s}[0], [x11], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + st1 {v17.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + st1 {v19.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + st1 {v21.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + st1 {v23.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + st1 {v25.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + st1 {v27.h}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + st1 {v29.h}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + st1 {v30.h}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + st1 {v31.h}[2], [x19] + add x11, x11, #6 + b WriteEnd + Write4: + add x2, x2, #8 + st1 {v16.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + 
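On the Write1..Write7 ladders here: each stores only the lanes valid for the current column tail (str h / st1 lane stores), and the cmp x6, #N / beq WriteEnd pairs stop as soon as the remaining rows run out. In scalar form this is a clamped tile store; a sketch with hypothetical names, assuming the same up-to-16x8 accumulator tile:

#include <stddef.h>

typedef _Float16 float16_t;

static void StoreTileClamped(float16_t *dst, size_t stride_bytes,
                             const float16_t acc[16][8],
                             int rows_left, int cols_left) {
  int rows = rows_left < 16 ? rows_left : 16;  /* cmp x6, #N / beq WriteEnd */
  int cols = cols_left < 8 ? cols_left : 8;    /* dispatch into Write1..Write8 */
  char *row = (char *)dst;
  for (int r = 0; r < rows; ++r) {
    float16_t *out = (float16_t *)row;
    for (int c = 0; c < cols; ++c) {
      out[c] = acc[r][c];                      /* st1 {vN.4h} plus lane stores */
    }
    row += stride_bytes;                       /* [x11], x8 post-increment */
  }
}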
st1 {v22.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.h}[4], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.h}[4], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.h}[4], [x19] + add x11, x11, #10 + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x10, x11, #12 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + st1 {v17.h}[6], [x10], x8 + cmp x6, #2 + 
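Before the WriteC8 path below: when writeMode is 0 this kernel does not produce row-major output at all. Each finished 16x8 tile is stored contiguously and the write pointer then advances by x16 = row * 8 * sizeof(float16_t), so the output is a sequence of 8-column panels, each panel holding every row; the #256 C8DstStep (16 rows * 8 cols * 2 bytes) then skips past the rows just written. The resulting address math, as read from that stepping code (a sketch only, not shipped code):

/* Offset, in halfwords, of element (r, c = 8 * panel + lane) in the
 * C8-packed output: panels of 8 columns, each row_total rows long. */
static inline int C8Offset(int row_total, int panel, int r, int lane) {
  return (panel * row_total + r) * 8 + lane;
}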
beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + st1 {v19.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + st1 {v21.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + st1 {v23.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + st1 {v25.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + st1 {v27.h}[6], [x10], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + st1 {v29.h}[6], [x10], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + st1 {v30.h}[6], [x10], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + st1 {v31.h}[6], [x10] + add x11, x11, #14 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x19], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x19], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x19], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v16.8h}, [x11], x15 + st1 {v17.8h}, [x11], x15 + st1 {v18.8h}, [x11], x15 + st1 {v19.8h}, [x11], x15 + st1 {v20.8h}, [x11], x15 + st1 {v21.8h}, [x11], x15 + st1 {v22.8h}, [x11], x15 + st1 {v23.8h}, [x11], x15 + st1 {v24.8h}, [x11], x15 + st1 {v25.8h}, [x11], x15 + st1 {v26.8h}, [x11], x15 + st1 {v27.8h}, [x11], x15 + st1 {v28.8h}, [x11], x15 + st1 {v29.8h}, [x11], x15 + st1 {v30.8h}, [x11], x15 + st1 {v31.8h}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.8h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.8h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.8h}, [x11], x8 + add x11, x11, #16 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #1 + ble LoopCol + cmp x6, #2 + ble LoopCol2 + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol16 + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + mov x21, #2 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C8DstStep: + add x2, x2, #256 + mov x11, 
x2 + NoDstStep: + subs x6, x6, #16 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16OptV2.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16OptV2.S new file mode 100644 index 00000000..545e0755 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulFp16OptV2.S @@ -0,0 +1,2966 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// size_t depth, size_t row, size_t col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFp16OptV2 + sub sp, sp, #192 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x29, x30, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] // writeMode + lsl x8, x8, #1 // stride * sizeof(float16_t) + + lsl x15, x7, #1 // col * sizeof(float16_t) + lsl x16, x5, #1 // depth * sizeof(float16_t) + mov x11, x2 + movi v7.8h, #0x46, lsl #8 + subs x6, x6, #12 + blt LoopRow8 +LoopRow12: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol12x8 + LoopCol12x16: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias12x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b Compute12x16Enter + InitFromBias12x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + ld1 {v16.8h, v17.8h}, [x12] + ld1 {v18.8h, v19.8h}, [x12] + ld1 {v20.8h, v21.8h}, [x12] + ld1 {v22.8h, v23.8h}, [x12] + ld1 {v24.8h, v25.8h}, [x12] + ld1 {v26.8h, v27.8h}, [x12] + ld1 {v28.8h, v29.8h}, [x12] + ld1 {v30.8h, v31.8h}, [x12] + add x12, x12, #32 + Compute12x16Enter: + bl Compute12x16Unit + Activation12x16: + cmp x4, #3 + beq Relu612x16 + cmp x4, #1 + beq Relu12x16 + b Write12x16 + + Relu612x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin 
v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v25.8h, v25.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v27.8h, v27.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v29.8h, v29.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + fmin v31.8h, v31.8h, v7.8h + + Relu12x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v25.8h, v25.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v27.8h, v27.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v29.8h, v29.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + fmax v31.8h, v31.8h, v6.8h + Write12x16: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + st1 {v16.8h, v17.8h}, [x23], x8 + st1 {v18.8h, v19.8h}, [x23], x8 + st1 {v20.8h, v21.8h}, [x23], x8 + st1 {v22.8h, v23.8h}, [x23] + st1 {v24.8h, v25.8h}, [x24], x8 + st1 {v26.8h, v27.8h}, [x24], x8 + st1 {v28.8h, v29.8h}, [x24], x8 + st1 {v30.8h, v31.8h}, [x24] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol12x16 + + LoopCol12x8: + adds x13, x13, #16 + cbz x13, LoopRow12End + subs x13, x13, #8 + blt LoopCol12x4 + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias12x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b Compute12x8Enter + InitFromBias12x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + ld1 {v16.8h}, [x12] + ld1 {v18.8h}, [x12] + ld1 {v20.8h}, [x12] + ld1 {v22.8h}, [x12] + ld1 {v24.8h}, [x12] + ld1 {v26.8h}, [x12] + ld1 {v28.8h}, [x12] + ld1 {v30.8h}, [x12] + add x12, x12, #16 + Compute12x8Enter: + bl Compute12x8Unit + Activation12x8: + cmp x4, #3 + beq Relu612x8 + cmp x4, #1 + beq Relu12x8 + b Write12x8 + + Relu612x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + + Relu12x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + Write12x8: + mov x22, x21 + add x23, 
x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + st1 {v16.8h}, [x23], x8 + st1 {v18.8h}, [x23], x8 + st1 {v20.8h}, [x23], x8 + st1 {v22.8h}, [x23] + st1 {v24.8h}, [x24], x8 + st1 {v26.8h}, [x24], x8 + st1 {v28.8h}, [x24], x8 + st1 {v30.8h}, [x24] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol12x4: + adds x13, x13, #8 + cbz x13, LoopRow12End + LoopCol12x4Core: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias12x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + dup v16.2s, wzr + dup v18.2s, wzr + dup v20.2s, wzr + dup v22.2s, wzr + dup v24.2s, wzr + dup v26.2s, wzr + dup v28.2s, wzr + dup v30.2s, wzr + b Compute12x4Enter + InitFromBias12x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + ld1 {v16.4h}, [x12] + ld1 {v18.4h}, [x12] + ld1 {v20.4h}, [x12] + ld1 {v22.4h}, [x12] + ld1 {v24.4h}, [x12] + ld1 {v26.4h}, [x12] + ld1 {v28.4h}, [x12] + ld1 {v30.4h}, [x12] + add x12, x12, #8 + Compute12x4Enter: + bl Compute12x4Unit + Activation12x4: + cmp x4, #3 + beq Relu612x4 + cmp x4, #1 + beq Relu12x4 + b Write12x4 + + Relu612x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + fmin v16.4h, v16.4h, v7.4h + fmin v18.4h, v18.4h, v7.4h + fmin v20.4h, v20.4h, v7.4h + fmin v22.4h, v22.4h, v7.4h + fmin v24.4h, v24.4h, v7.4h + fmin v26.4h, v26.4h, v7.4h + fmin v28.4h, v28.4h, v7.4h + fmin v30.4h, v30.4h, v7.4h + + Relu12x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + fmax v16.4h, v16.4h, v6.4h + fmax v18.4h, v18.4h, v6.4h + fmax v20.4h, v20.4h, v6.4h + fmax v22.4h, v22.4h, v6.4h + fmax v24.4h, v24.4h, v6.4h + fmax v26.4h, v26.4h, v6.4h + fmax v28.4h, v28.4h, v6.4h + fmax v30.4h, v30.4h, v6.4h + Write12x4: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + cmp x13, #1 + beq Write12x1 + cmp x13, #2 + beq Write12x2 + cmp x13, #3 + beq Write12x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + st1 {v16.4h}, [x23], x8 + st1 {v18.4h}, [x23], x8 + st1 {v20.4h}, [x23], x8 + st1 {v22.4h}, [x23] + st1 {v24.4h}, [x24], x8 + st1 {v26.4h}, [x24], x8 + st1 {v28.4h}, [x24], x8 + st1 {v30.4h}, [x24] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol12x4Core + b LoopRow12End + Write12x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + st1 {v16.h}[0], [x23], x8 + st1 {v18.h}[0], [x23], x8 + st1 {v20.h}[0], [x23], x8 + st1 {v22.h}[0], [x23] + st1 {v24.h}[0], [x24], x8 + st1 {v26.h}[0], [x24], x8 + st1 {v28.h}[0], [x24], x8 + st1 {v30.h}[0], [x24] + b LoopRow12End + Write12x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + st1 {v16.s}[0], [x23], x8 + st1 {v18.s}[0], [x23], x8 + st1 {v20.s}[0], [x23], x8 + st1 {v22.s}[0], [x23] + st1 {v24.s}[0], [x24], x8 + st1 {v26.s}[0], [x24], x8 + st1 {v28.s}[0], [x24], x8 + st1 {v30.s}[0], [x24] + b LoopRow12End + Write12x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + st1 
{v16.s}[0], [x22], x8 + st1 {v16.h}[2], [x23], x8 + st1 {v18.s}[0], [x22], x8 + st1 {v18.h}[2], [x23], x8 + st1 {v20.s}[0], [x22], x8 + st1 {v20.h}[2], [x23], x8 + st1 {v22.s}[0], [x22], x8 + st1 {v22.h}[2], [x23], x8 + st1 {v24.s}[0], [x22], x8 + st1 {v24.h}[2], [x23], x8 + st1 {v26.s}[0], [x22], x8 + st1 {v26.h}[2], [x23], x8 + st1 {v28.s}[0], [x22], x8 + st1 {v28.h}[2], [x23], x8 + st1 {v30.s}[0], [x22] + st1 {v30.h}[2], [x23] + LoopRow12End: + add x0, x0, x16, lsl #3 + add x0, x0, x16, lsl #2 + add x2, x2, x8, lsl #3 + add x2, x2, x8, lsl #2 + subs x6, x6, #12 + bge LoopRow12 + +LoopRow8: + adds x6, x6,#12 + cbz x6, End + subs x6, x6, #8 + blt LoopRow4 + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol8x8 + LoopCol8x16: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias8x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b Compute8x16Enter + InitFromBias8x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + ld1 {v16.8h, v17.8h}, [x12] + ld1 {v18.8h, v19.8h}, [x12] + ld1 {v20.8h, v21.8h}, [x12] + ld1 {v22.8h, v23.8h}, [x12] + add x12, x12, #32 + Compute8x16Enter: + bl Compute8x16Unit + Activation8x16: + cmp x4, #3 + beq Relu68x16 + cmp x4, #1 + beq Relu8x16 + b Write8x16 + + Relu68x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + + Relu8x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + Write8x16: + mov x22, x21 + add x23, x21, x8, lsl #2 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + st1 {v16.8h, v17.8h}, [x23], x8 + st1 {v18.8h, v19.8h}, [x23], x8 + st1 {v20.8h, v21.8h}, [x23], x8 + st1 {v22.8h, v23.8h}, [x23] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol8x16 + + LoopCol8x8: + adds x13, x13, #16 + cbz x13, LoopRow8End + subs x13, x13, #8 + blt LoopCol8x4 + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias8x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b Compute8x8Enter + InitFromBias8x8: + ld1 {v8.8h}, [x12] + ld1 
{v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + ld1 {v16.8h}, [x12] + ld1 {v18.8h}, [x12] + ld1 {v20.8h}, [x12] + ld1 {v22.8h}, [x12] + add x12, x12, #16 + Compute8x8Enter: + bl Compute8x8Unit + Activation8x8: + cmp x4, #3 + beq Relu68x8 + cmp x4, #1 + beq Relu8x8 + b Write8x8 + + Relu68x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + + Relu8x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + Write8x8: + mov x22, x21 + add x23, x21, x8, lsl #2 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + st1 {v16.8h}, [x23], x8 + st1 {v18.8h}, [x23], x8 + st1 {v20.8h}, [x23], x8 + st1 {v22.8h}, [x23] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol8x4: + adds x13, x13, #8 + cbz x13, LoopRow8End + LoopCol8x4Core: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias8x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + dup v16.2s, wzr + dup v18.2s, wzr + dup v20.2s, wzr + dup v22.2s, wzr + b Compute8x4Enter + InitFromBias8x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + ld1 {v16.4h}, [x12] + ld1 {v18.4h}, [x12] + ld1 {v20.4h}, [x12] + ld1 {v22.4h}, [x12] + add x12, x12, #8 + Compute8x4Enter: + bl Compute8x4Unit + Activation8x4: + cmp x4, #3 + beq Relu68x4 + cmp x4, #1 + beq Relu8x4 + b Write8x4 + + Relu68x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + fmin v16.4h, v16.4h, v7.4h + fmin v18.4h, v18.4h, v7.4h + fmin v20.4h, v20.4h, v7.4h + fmin v22.4h, v22.4h, v7.4h + + Relu8x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + fmax v16.4h, v16.4h, v6.4h + fmax v18.4h, v18.4h, v6.4h + fmax v20.4h, v20.4h, v6.4h + fmax v22.4h, v22.4h, v6.4h + Write8x4: + mov x22, x21 + add x23, x21, x8, lsl #2 + cmp x13, #1 + beq Write8x1 + cmp x13, #2 + beq Write8x2 + cmp x13, #3 + beq Write8x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + st1 {v16.4h}, [x23], x8 + st1 {v18.4h}, [x23], x8 + st1 {v20.4h}, [x23], x8 + st1 {v22.4h}, [x23] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol8x4Core + b LoopRow8End + Write8x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + st1 {v16.h}[0], [x23], x8 + st1 {v18.h}[0], [x23], x8 + st1 {v20.h}[0], [x23], x8 + st1 {v22.h}[0], [x23] + b LoopRow8End + Write8x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + st1 {v16.s}[0], [x23], x8 + st1 {v18.s}[0], [x23], x8 + st1 {v20.s}[0], [x23], x8 + st1 {v22.s}[0], [x23] + b LoopRow8End + Write8x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + st1 {v16.s}[0], [x22], x8 + st1 {v16.h}[2], [x23], x8 + 
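+    // rows 5-7 of the 8-row x 3-column remainder (x22 stores the .s pair, x23 the third half)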
st1 {v18.s}[0], [x22], x8 + st1 {v18.h}[2], [x23], x8 + st1 {v20.s}[0], [x22], x8 + st1 {v20.h}[2], [x23], x8 + st1 {v22.s}[0], [x22], x8 + st1 {v22.h}[2], [x23], x8 + LoopRow8End: + add x0, x0, x16, lsl #3 + add x2, x2, x8, lsl #3 + subs x6, x6, #8 + +LoopRow4: + adds x6, x6, #8 + cbz x6, End + subs x6, x6, #4 + blt LoopRowTail + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol4x8 + LoopCol4x16: + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias4x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b Compute4x16Enter + InitFromBias4x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + add x12, x12, #32 + Compute4x16Enter: + bl Compute4x16Unit + Activation4x16: + cmp x4, #3 + beq Relu64x16 + cmp x4, #1 + beq Relu4x16 + b Write4x16 + + Relu64x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + + Relu4x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + Write4x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol4x16 + + LoopCol4x8: + adds x13, x13, #16 + cbz x13, LoopRow4End + subs x13, x13, #8 + blt LoopCol4x4 + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias4x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b Compute4x8Enter + InitFromBias4x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + add x12, x12, #16 + Compute4x8Enter: + bl Compute4x8Unit + Activation4x8: + cmp x4, #3 + beq Relu64x8 + cmp x4, #1 + beq Relu4x8 + b Write4x8 + + Relu64x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + + Relu4x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + Write4x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol4x4: + adds x13, x13, #8 + cbz x13, LoopRow4End + LoopCol4x4Core: + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias4x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + b Compute4x4Enter + InitFromBias4x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + add x12, x12, #8 + Compute4x4Enter: + bl Compute4x4Unit + Activation4x4: + cmp x4, #3 + beq Relu64x4 + cmp x4, #1 + beq Relu4x4 + b Write4x4 + + Relu64x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + 
fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + + Relu4x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + Write4x4: + mov x22, x21 + cmp x13, #1 + beq Write4x1 + cmp x13, #2 + beq Write4x2 + cmp x13, #3 + beq Write4x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol4x4Core + b LoopRow4End + Write4x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + b LoopRow4End + Write4x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + b LoopRow4End + Write4x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + LoopRow4End: + add x0, x0, x16, lsl #2 + add x2, x2, x8, lsl #2 + subs x6, x6, #4 + +LoopRowTail: + adds x6, x6, #4 + cbz x6, End + cmp x6, #1 + beq LoopRow1 + cmp x6, #2 + beq LoopRow2 + // LoopRow3 + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol3x8 + LoopCol3x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + b Compute3x16Enter + InitFromBias3x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + add x12, x12, #32 + Compute3x16Enter: + bl Compute3x16Unit + Activation3x16: + cmp x4, #3 + beq Relu63x16 + cmp x4, #1 + beq Relu3x16 + b Write3x16 + + Relu63x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + + Relu3x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + Write3x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol3x16 + + LoopCol3x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol3x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + b Compute3x8Enter + InitFromBias3x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + add x12, x12, #16 + Compute3x8Enter: + bl Compute3x8Unit + Activation3x8: + cmp x4, #3 + beq Relu63x8 + cmp x4, #1 + beq Relu3x8 + b Write3x8 + + Relu63x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + + Relu3x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + Write3x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol3x4: + adds x13, x13, #8 + cbz x13, End + LoopCol3x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + b Compute3x4Enter + InitFromBias3x4: + ld1 {v8.4h}, [x12] + ld1 
{v10.4h}, [x12] + ld1 {v12.4h}, [x12] + add x12, x12, #8 + Compute3x4Enter: + bl Compute3x4Unit + Activation3x4: + cmp x4, #3 + beq Relu63x4 + cmp x4, #1 + beq Relu3x4 + b Write3x4 + + Relu63x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + + Relu3x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + Write3x4: + mov x22, x21 + cmp x13, #1 + beq Write3x1 + cmp x13, #2 + beq Write3x2 + cmp x13, #3 + beq Write3x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol3x4Core + b End + Write3x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22] + b End + Write3x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22] + b End + Write3x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + b End + +LoopRow2: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol2x8 + LoopCol2x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + b Compute2x16Enter + InitFromBias2x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + add x12, x12, #32 + Compute2x16Enter: + bl Compute2x16Unit + Activation2x16: + cmp x4, #3 + beq Relu62x16 + cmp x4, #1 + beq Relu2x16 + b Write2x16 + + Relu62x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + + Relu2x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + Write2x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol2x16 + + LoopCol2x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol2x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x8 + dup v8.2d, xzr + dup v10.2d, xzr + b Compute2x8Enter + InitFromBias2x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + add x12, x12, #16 + Compute2x8Enter: + bl Compute2x8Unit + Activation2x8: + cmp x4, #3 + beq Relu62x8 + cmp x4, #1 + beq Relu2x8 + b Write2x8 + + Relu62x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + + Relu2x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + Write2x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol2x4: + adds x13, x13, #8 + cbz x13, End + LoopCol2x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x4 + dup v8.2s, wzr + dup v10.2s, wzr + b Compute2x4Enter + InitFromBias2x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + add x12, x12, #8 + Compute2x4Enter: + bl Compute2x4Unit + Activation2x4: + cmp x4, #3 + beq Relu62x4 + cmp x4, #1 + beq Relu2x4 + b Write2x4 + + Relu62x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + Relu2x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + Write2x4: + mov x22, x21 + cmp x13, #1 + beq Write2x1 + cmp x13, #2 + beq Write2x2 + cmp x13, #3 + beq Write2x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + 
bgt LoopCol2x4Core + b End + Write2x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22] + b End + Write2x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22] + b End + Write2x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + b End + +LoopRow1: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol1x8 + LoopCol1x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x16 + dup v8.2d, xzr + dup v9.2d, xzr + b Compute1x16Enter + InitFromBias1x16: + ld1 {v8.8h, v9.8h}, [x12], #32 + Compute1x16Enter: + bl Compute1x16Unit + Activation1x16: + cmp x4, #3 + beq Relu61x16 + cmp x4, #1 + beq Relu1x16 + b Write1x16 + + Relu61x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + + Relu1x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + Write1x16: + st1 {v8.8h, v9.8h}, [x21], #32 + subs x13, x13, #16 + bge LoopCol1x16 + + LoopCol1x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol1x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x8 + dup v8.2d, xzr + b Compute1x8Enter + InitFromBias1x8: + ld1 {v8.8h}, [x12], #16 + Compute1x8Enter: + bl Compute1x8Unit + Activation1x8: + cmp x4, #3 + beq Relu61x8 + cmp x4, #1 + beq Relu1x8 + b Write1x8 + + Relu61x8: + fmin v8.8h, v8.8h, v7.8h + + Relu1x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + Write1x8: + st1 {v8.8h}, [x21], #16 + subs x13, x13, #8 + + LoopCol1x4: + adds x13, x13, #8 + cbz x13, End + LoopCol1x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x4 + dup v8.2s, wzr + b Compute1x4Enter + InitFromBias1x4: + ld1 {v8.4h}, [x12], #8 + Compute1x4Enter: + bl Compute1x4Unit + Activation1x4: + cmp x4, #3 + beq Relu61x4 + cmp x4, #1 + beq Relu1x4 + b Write1x4 + + Relu61x4: + fmin v8.4h, v8.4h, v7.4h + Relu1x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + Write1x4: + cmp x13, #1 + beq Write1x1 + cmp x13, #2 + beq Write1x2 + cmp x13, #3 + beq Write1x3 + st1 {v8.4h}, [x21], #8 + subs x13, x13, #4 + bgt LoopCol1x4Core + b End + Write1x1: + st1 {v8.h}[0], [x21] + b End + Write1x2: + st1 {v8.s}[0], [x21] + b End + Write1x3: + add x22, x21, #4 + st1 {v8.s}[0], [x21] + st1 {v8.h}[2], [x22] + b End + +Compute12x16Unit: + subs x14, x14, #2 + ble Compute12x16End + Compute12x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + + fmla v8.8h, v5.8h, v1.h[4] + fmla v10.8h, v5.8h, v1.h[5] + fmla v12.8h, v5.8h, v1.h[6] + fmla v14.8h, v5.8h, v1.h[7] + fmla v16.8h, v5.8h, v2.h[0] + fmla v18.8h, v5.8h, v2.h[1] + 
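+    // rows 6-11 of the second depth step: A scalars in v2, B columns 0-7 of row k+1 in v5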
fmla v20.8h, v5.8h, v2.h[2] + fmla v22.8h, v5.8h, v2.h[3] + fmla v24.8h, v5.8h, v2.h[4] + fmla v26.8h, v5.8h, v2.h[5] + fmla v28.8h, v5.8h, v2.h[6] + fmla v30.8h, v5.8h, v2.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[4] + fmla v11.8h, v6.8h, v1.h[5] + fmla v13.8h, v6.8h, v1.h[6] + fmla v15.8h, v6.8h, v1.h[7] + prfm pldl1keep, [x10, #632] + ld1 {v0.8h}, [x10], #16 + fmla v17.8h, v6.8h, v2.h[0] + fmla v19.8h, v6.8h, v2.h[1] + fmla v21.8h, v6.8h, v2.h[2] + fmla v23.8h, v6.8h, v2.h[3] + fmla v25.8h, v6.8h, v2.h[4] + fmla v27.8h, v6.8h, v2.h[5] + fmla v29.8h, v6.8h, v2.h[6] + fmla v31.8h, v6.8h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x16 + Compute12x16End: + cbnz x14, Compute12x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + ld1 {v2.8h}, [x10], #16 + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + mov v0.16b, v2.16b + Compute12x16End1: + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + ret + +Compute12x8Unit: + subs x14, x14, #2 + ble Compute12x8End + Compute12x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[4] + fmla v10.8h, v4.8h, v1.h[5] + fmla v12.8h, v4.8h, v1.h[6] + fmla v14.8h, v4.8h, v1.h[7] + ld1 {v0.8h}, [x10], #16 + fmla v16.8h, v4.8h, v2.h[0] + fmla v18.8h, v4.8h, v2.h[1] + fmla v20.8h, v4.8h, v2.h[2] + fmla v22.8h, v4.8h, v2.h[3] + fmla v24.8h, v4.8h, v2.h[4] + fmla v26.8h, v4.8h, v2.h[5] + fmla v28.8h, v4.8h, v2.h[6] + fmla v30.8h, v4.8h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x8 + Compute12x8End: + cbnz x14, Compute12x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, 
[x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + mov v3.16b, v4.16b + Compute12x8End1: + ld1 {v1.4h}, [x10] + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + ret + +Compute12x4Unit: + subs x14, x14, #2 + ble Compute12x4End + Compute12x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[4] + fmla v10.4h, v4.4h, v1.h[5] + fmla v12.4h, v4.4h, v1.h[6] + fmla v14.4h, v4.4h, v1.h[7] + ld1 {v0.8h}, [x10], #16 + fmla v16.4h, v4.4h, v2.h[0] + fmla v18.4h, v4.4h, v2.h[1] + fmla v20.4h, v4.4h, v2.h[2] + fmla v22.4h, v4.4h, v2.h[3] + fmla v24.4h, v4.4h, v2.h[4] + fmla v26.4h, v4.4h, v2.h[5] + fmla v28.4h, v4.4h, v2.h[6] + fmla v30.4h, v4.4h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x4 + Compute12x4End: + cbnz x14, Compute12x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + mov v3.8b, v4.8b + Compute12x4End1: + ld1 {v1.4h}, [x10] + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + ret + +Compute8x16Unit: + subs x14, x14, #2 + ble Compute8x16End + Compute8x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + + fmla v8.8h, v5.8h, v1.h[0] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v1.h[2] + 
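+    // rows 3-7, low 8 output columns, depth step k+1 (B row k+1 in v5)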
fmla v14.8h, v5.8h, v1.h[3] + fmla v16.8h, v5.8h, v1.h[4] + fmla v18.8h, v5.8h, v1.h[5] + fmla v20.8h, v5.8h, v1.h[6] + fmla v22.8h, v5.8h, v1.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[0] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v1.h[2] + fmla v15.8h, v6.8h, v1.h[3] + prfm pldl1keep, [x10, #632] + ld1 {v0.8h}, [x10], #16 + fmla v17.8h, v6.8h, v1.h[4] + fmla v19.8h, v6.8h, v1.h[5] + fmla v21.8h, v6.8h, v1.h[6] + fmla v23.8h, v6.8h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x16 + Compute8x16End: + cbnz x14, Compute8x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + mov v0.16b, v1.16b + Compute8x16End1: + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + ret + +Compute8x8Unit: + subs x14, x14, #2 + ble Compute8x8End + Compute8x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v1.h[2] + fmla v14.8h, v4.8h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + fmla v16.8h, v4.8h, v1.h[4] + fmla v18.8h, v4.8h, v1.h[5] + fmla v20.8h, v4.8h, v1.h[6] + fmla v22.8h, v4.8h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x8 + Compute8x8End: + cbnz x14, Compute8x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + mov v0.16b, v1.16b + mov v3.16b, v4.16b + Compute8x8End1: + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + ret + +Compute8x4Unit: + subs x14, x14, #2 + ble Compute8x4End + Compute8x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, 
v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v1.h[2] + fmla v14.4h, v4.4h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + fmla v16.4h, v4.4h, v1.h[4] + fmla v18.4h, v4.4h, v1.h[5] + fmla v20.4h, v4.4h, v1.h[6] + fmla v22.4h, v4.4h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x4 + Compute8x4End: + cbnz x14, Compute8x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + mov v0.16b, v1.16b + mov v3.8b, v4.8b + Compute8x4End1: + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + ret + +Compute4x16Unit: + subs x14, x14, #2 + ble Compute4x16End + Compute4x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + + fmla v8.8h, v5.8h, v1.h[0] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v1.h[2] + fmla v14.8h, v5.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[0] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v1.h[2] + fmla v15.8h, v6.8h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x16 + Compute4x16End: + cbnz x14, Compute4x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + mov v0.8b, v1.8b + Compute4x16End1: + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + ret + +Compute4x8Unit: + subs x14, x14, #2 + ble Compute4x8End + Compute4x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v1.h[2] + fmla v14.8h, v4.8h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x8 + Compute4x8End: + cbnz x14, Compute4x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + mov v0.8b, v1.8b + mov v3.16b, v4.16b + Compute4x8End1: + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + ret + +Compute4x4Unit: + subs x14, x14, #2 + ble 
Compute4x4End + Compute4x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v1.h[2] + fmla v14.4h, v4.4h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x4 + Compute4x4End: + cbnz x14, Compute4x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + mov v0.8b, v1.8b + mov v3.8b, v4.8b + Compute4x4End1: + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + ret + +Compute3x16Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x16End4 + Compute3x16: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v12.8h, v5.8h, v2.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + fmla v13.8h, v6.8h, v2.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + fmla v12.8h, v3.8h, v2.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v11.8h, v4.8h, v1.h[4] + fmla v13.8h, v4.8h, v2.h[4] + fmla v8.8h, v5.8h, v0.h[5] + fmla v10.8h, v5.8h, v1.h[5] + fmla v12.8h, v5.8h, v2.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v11.8h, v6.8h, v1.h[5] + fmla v13.8h, v6.8h, v2.h[5] + fmla v8.8h, v3.8h, v0.h[6] + fmla v10.8h, v3.8h, v1.h[6] + fmla v12.8h, v3.8h, v2.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v11.8h, v4.8h, v1.h[6] + fmla v13.8h, v4.8h, v2.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v10.8h, v5.8h, v1.h[7] + fmla v12.8h, v5.8h, v2.h[7] + fmla v9.8h, v6.8h, v0.h[7] + fmla v11.8h, v6.8h, v1.h[7] + fmla v13.8h, v6.8h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x16 + Compute3x16End4: + adds x14, x14, #8 + cbz x14, Compute3x16Return + subs x14, x14, #4 + blt Compute3x16EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + 
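+    // depth steps 2 and 3 of the four-element tail (A scalars .h[2] and .h[3])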
fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v12.8h, v5.8h, v2.h[3] + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + fmla v13.8h, v6.8h, v2.h[3] + subs x14, x14, #4 + Compute3x16EndTail: + adds x14, x14, #4 + cbz x14, Compute3x16Return + cmp x14, #1 + beq Compute3x16EndTail1 + cmp x14, #2 + beq Compute3x16EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + b Compute3x16Return + Compute3x16EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + b Compute3x16Return + Compute3x16EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + Compute3x16Return: + ret + +Compute3x8Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x8End4 + Compute3x8: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v12.8h, v6.8h, v2.h[3] + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + fmla v12.8h, v3.8h, v2.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v10.8h, v4.8h, v1.h[5] + fmla v12.8h, v4.8h, v2.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v10.8h, v5.8h, v1.h[6] + fmla v12.8h, v5.8h, v2.h[6] + fmla v8.8h, v6.8h, v0.h[7] + fmla v10.8h, v6.8h, v1.h[7] + fmla v12.8h, v6.8h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x8 + Compute3x8End4: + adds x14, x14, #8 + cbz x14, Compute3x8Return + subs x14, x14, #4 + blt Compute3x8EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, 
[x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v12.8h, v6.8h, v2.h[3] + subs x14, x14, #4 + Compute3x8EndTail: + adds x14, x14, #4 + cbz x14, Compute3x8Return + cmp x14, #1 + beq Compute3x8EndTail1 + cmp x14, #2 + beq Compute3x8EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h}, [x11], #16 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + b Compute3x8Return + Compute3x8EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld2 {v2.h, v3.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v5.8h, v0.h[0] + fmla v10.8h, v5.8h, v1.h[0] + fmla v12.8h, v5.8h, v2.h[0] + fmla v8.8h, v6.8h, v0.h[1] + fmla v10.8h, v6.8h, v1.h[1] + fmla v12.8h, v6.8h, v3.h[0] + b Compute3x8Return + Compute3x8EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + Compute3x8Return: + ret + +Compute3x4Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x4End4 + Compute3x4: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v12.4h, v6.4h, v2.h[3] + fmla v8.4h, v3.4h, v0.h[4] + fmla v10.4h, v3.4h, v1.h[4] + fmla v12.4h, v3.4h, v2.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v10.4h, v4.4h, v1.h[5] + fmla v12.4h, v4.4h, v2.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v10.4h, v5.4h, v1.h[6] + fmla v12.4h, v5.4h, v2.h[6] + fmla v8.4h, v6.4h, v0.h[7] + fmla v10.4h, v6.4h, v1.h[7] + fmla v12.4h, v6.4h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x4 + Compute3x4End4: + adds x14, x14, #8 + cbz x14, Compute3x4Return + subs x14, x14, #4 + blt Compute3x4EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v12.4h, v6.4h, v2.h[3] + subs x14, x14, #4 + 
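+  // fewer than four depth values remain; dispatch on the exact remainder (1, 2 or 3)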
Compute3x4EndTail: + adds x14, x14, #4 + cbz x14, Compute3x4Return + cmp x14, #1 + beq Compute3x4EndTail1 + cmp x14, #2 + beq Compute3x4EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h}, [x11], #8 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + b Compute3x4Return + Compute3x4EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld2 {v2.h, v3.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v5.4h, v0.h[0] + fmla v10.4h, v5.4h, v1.h[0] + fmla v12.4h, v5.4h, v2.h[0] + fmla v8.4h, v6.4h, v0.h[1] + fmla v10.4h, v6.4h, v1.h[1] + fmla v12.4h, v6.4h, v3.h[0] + b Compute3x4Return + Compute3x4EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + Compute3x4Return: + ret + +Compute2x16Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x16End4 + Compute2x16: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v11.8h, v4.8h, v1.h[4] + fmla v8.8h, v5.8h, v0.h[5] + fmla v10.8h, v5.8h, v1.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v11.8h, v6.8h, v1.h[5] + fmla v8.8h, v3.8h, v0.h[6] + fmla v10.8h, v3.8h, v1.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v11.8h, v4.8h, v1.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v10.8h, v5.8h, v1.h[7] + fmla v9.8h, v6.8h, v0.h[7] + fmla v11.8h, v6.8h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x16 + Compute2x16End4: + adds x14, x14, #8 + cbz x14, Compute2x16Return + subs x14, x14, #4 + blt Compute2x16EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + subs x14, x14, #4 + Compute2x16EndTail: + adds x14, x14, #4 + cbz 
x14, Compute2x16Return + cmp x14, #1 + beq Compute2x16EndTail1 + cmp x14, #2 + beq Compute2x16EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.s}[0], [x19], #4 + ld1 {v1.h}[2], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + b Compute2x16Return + Compute2x16EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v2.h[0] + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v2.h[0] + b Compute2x16Return + Compute2x16EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + Compute2x16Return: + ret + +Compute2x8Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x8End4 + Compute2x8: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v10.8h, v4.8h, v1.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v10.8h, v5.8h, v1.h[6] + fmla v8.8h, v6.8h, v0.h[7] + fmla v10.8h, v6.8h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x8 + Compute2x8End4: + adds x14, x14, #8 + cbz x14, Compute2x8Return + subs x14, x14, #4 + blt Compute2x8EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + subs x14, x14, #4 + Compute2x8EndTail: + adds x14, x14, #4 + cbz x14, Compute2x8Return + cmp x14, #1 + beq Compute2x8EndTail1 + cmp x14, #2 + beq Compute2x8EndTail2 + ld1 {v0.4h}, [x10] + ld3 {v1.h, v2.h, v3.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[0] + fmla v10.8h, v4.8h, v1.h[0] + ld1 {v6.8h}, [x11], #16 + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v2.h[0] + fmla v8.8h, v6.8h, v0.h[2] + fmla v10.8h, v6.8h, v3.h[0] + b Compute2x8Return + Compute2x8EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v2.h[0] + b Compute2x8Return + Compute2x8EndTail1: + ld1 {v0.h}[0], [x10] + 
ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + Compute2x8Return: + ret + +Compute2x4Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x4End4 + Compute2x4: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v8.4h, v3.4h, v0.h[4] + fmla v10.4h, v3.4h, v1.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v10.4h, v4.4h, v1.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v10.4h, v5.4h, v1.h[6] + fmla v8.4h, v6.4h, v0.h[7] + fmla v10.4h, v6.4h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x4 + Compute2x4End4: + adds x14, x14, #8 + cbz x14, Compute2x4Return + subs x14, x14, #4 + blt Compute2x4EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + subs x14, x14, #4 + Compute2x4EndTail: + adds x14, x14, #4 + cbz x14, Compute2x4Return + cmp x14, #1 + beq Compute2x4EndTail1 + cmp x14, #2 + beq Compute2x4EndTail2 + ld1 {v0.4h}, [x10] + ld3 {v1.h, v2.h, v3.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v4.4h, v5.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[0] + fmla v10.4h, v4.4h, v1.h[0] + ld1 {v6.4h}, [x11], #8 + fmla v8.4h, v5.4h, v0.h[1] + fmla v10.4h, v5.4h, v2.h[0] + fmla v8.4h, v6.4h, v0.h[2] + fmla v10.4h, v6.4h, v3.h[0] + b Compute2x4Return + Compute2x4EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v2.h[0] + b Compute2x4Return + Compute2x4EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + Compute2x4Return: + ret + +Compute1x16Unit: + subs x14, x14, #8 + blt Compute1x16End4 + Compute1x16: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v0.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v8.8h, v3.8h, v0.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v8.8h, v5.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v8.8h, v5.8h, v0.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v8.8h, v3.8h, v0.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v9.8h, v6.8h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x16 + Compute1x16End4: + adds x14, x14, #8 + cbz x14, Compute1x16Return + subs x14, x14, #4 + blt Compute1x16EndTail + ld1 {v0.4h}, 
[x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v0.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v8.8h, v3.8h, v0.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v9.8h, v6.8h, v0.h[3] + subs x14, x14, #4 + Compute1x16EndTail: + adds x14, x14, #4 + cbz x14, Compute1x16Return + cmp x14, #1 + beq Compute1x16EndTail1 + cmp x14, #2 + beq Compute1x16EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v1.h[0] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v1.h[0] + fmla v8.8h, v3.8h, v2.h[0] + fmla v9.8h, v4.8h, v2.h[0] + b Compute1x16Return + Compute1x16EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v1.h[0] + fmla v9.8h, v6.8h, v1.h[0] + b Compute1x16Return + Compute1x16EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + Compute1x16Return: + ret + +Compute1x8Unit: + subs x14, x14, #8 + blt Compute1x8End4 + Compute1x8: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v8.8h, v5.8h, v0.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v8.8h, v3.8h, v0.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v8.8h, v6.8h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x8 + Compute1x8End4: + adds x14, x14, #8 + cbz x14, Compute1x8Return + subs x14, x14, #4 + blt Compute1x8EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v8.8h, v6.8h, v0.h[3] + subs x14, x14, #4 + Compute1x8EndTail: + adds x14, x14, #4 + cbz x14, Compute1x8Return + cmp x14, #1 + beq Compute1x8EndTail1 + cmp x14, #2 + beq Compute1x8EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v2.h[0] + b Compute1x8Return + Compute1x8EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v8.8h, v4.8h, v1.h[0] + b Compute1x8Return + Compute1x8EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + Compute1x8Return: + ret + +Compute1x4Unit: + subs x14, x14, #8 + blt Compute1x4End4 + Compute1x4: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v8.4h, v5.4h, v0.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v8.4h, v3.4h, v0.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v8.4h, 
v5.4h, v0.h[6] + fmla v8.4h, v6.4h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x4 + Compute1x4End4: + adds x14, x14, #8 + cbz x14, Compute1x4Return + subs x14, x14, #4 + blt Compute1x4EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v8.4h, v6.4h, v0.h[3] + subs x14, x14, #4 + Compute1x4EndTail: + adds x14, x14, #4 + cbz x14, Compute1x4Return + cmp x14, #1 + beq Compute1x4EndTail1 + cmp x14, #2 + beq Compute1x4EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v8.4h, v5.4h, v2.h[0] + b Compute1x4Return + Compute1x4EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v8.4h, v4.4h, v1.h[0] + b Compute1x4Return + Compute1x4EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + Compute1x4Return: + ret + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulWinogradFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulWinogradFp16.S new file mode 100644 index 00000000..ac135170 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/MatmulWinogradFp16.S @@ -0,0 +1,217 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// MatrixMultiplyWinogradFp16(float16_t *matrix_a, float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, int in_channel) + // x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel +asm_function MatrixMultiplyWinogradFp16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should also be preserved + // whereas our coding style does not permit such a number of parameters + sub sp, sp, #48 + st1 {v8.8h}, [sp] + stp x19, x20, [sp, #16] + stp x21, x22, [sp, #32] + + mov x8, #2 + mul x10, x5, x8 // n * 2 + mov x17, x3 // m + mul x13, x6, x8 // in_channel * 2 + mul x21, x13, x4 // in_channel * k * 2 + + LoopM: + mov x15, x5 // n + mov x14, x1 // mat_b + LoopN: + mov x16, x0 // mat_a_m + sub x22, x5, x15 // ni + sub x19, x17, x3 // mi + mul x22, x22, x17 // ni * m + mov x11, x6 // in_channel + add x22, x22, x19 // (ni * m) + mi + mul x22, x22, x13 // x22 * channel_in * 2 + add x20, x2, x22 // dst + offset + cmp x11, #32 + bge LoopC32 + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + b EndLoopC + LoopC32: + mov x12, x14 + mov x9, x4 // new_k + dup v5.8h, wzr + dup v6.8h, wzr + dup v7.8h, wzr + dup v8.8h, wzr + LoopK32: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + fmla v6.8h, v1.8h, v4.h[0] + fmla v7.8h, v2.8h, v4.h[0] + fmla v8.8h, v3.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK32 + Write32: + st1 {v5.8h}, [x20], #16 + st1 {v6.8h}, [x20], #16 + st1 {v7.8h}, [x20], #16 + st1 {v8.8h}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #64 // add 64B + subs x11, x11, #32 + beq EndLoopC + cmp x11, #32 + bge LoopC32 + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC16: + dup v5.8h, wzr + dup v6.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK16: + ld1 {v0.8h, v1.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + fmla v6.8h, v1.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK16 + Write16: + st1 {v5.8h}, [x20], #16 + st1 {v6.8h}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #32 // add 32B + subs x11, x11, #16 + beq EndLoopC + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC8: + dup v5.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK8: + ld1 {v0.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK8 + Write8: + st1 {v5.8h}, [x20], #16 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #16 // add 16B + subs x11, x11, #8 + beq EndLoopC + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC4: + dup v5.4h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK4: + ld1 {v0.4h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.4h, v0.4h, v4.h[0] + subs x9, x9, #1 + bne LoopK4 + Write4: + st1 {v5.4h}, [x20], #8 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #8 // add 8B + subs x11, x11, #4 + beq EndLoopC + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC: + dup v5.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK: + ldr h0, [x16] + add x16, x16, x13 + ldr h4, [x12] + add x12, x12, x10 + fmul 
h0, h0, h4 + fadd h5, h5, h0 + subs x9, x9, #1 + bne LoopK + Write: + str h5, [x20], #2 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #2 // ptr add 2B + subs x11, x11, #1 + beq EndLoopC + b LoopC + + EndLoopC: + add x14, x14, #2 + subs x15, x15, #1 + beq EndLoopN + b LoopN + EndLoopN: + subs x3, x3, #1 + beq EndLoopM + add x0, x0, x21 + b LoopM + + EndLoopM: + ld1 {v8.8h}, [sp], #16 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC4Fp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC4Fp16.S new file mode 100644 index 00000000..82fff430 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC4Fp16.S @@ -0,0 +1,293 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// x0 dst x1 src x2 bias +// w3 oc4div w4 oc4mod w5 plane_size +// x6 plane_stride x7 relu_type + +asm_function PostFuncBiasReluC4Fp16 + + movi v26.4h, #6 + scvtf v26.4h, v26.4h + dup v27.4h, wzr + + mov x10, #2 + add x12, x3, x4 + mul x12, x12, x10 + + mov w10, #0 + +Loop_C4: + cmp w10, w3 + beq Loop_C1 + mov x15, #2 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #4 + mov w13, w5 + ld1 {v16.4h}, [x2], #8 + +Loop_8x4: + cmp w13, #8 + blt Loop_4x4 + sub w13, w13, #8 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x1], #32 + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x1], #32 + + fadd v0.4h, v0.4h, v16.4h + fadd v1.4h, v1.4h, v16.4h + fadd v2.4h, v2.4h, v16.4h + fadd v3.4h, v3.4h, v16.4h + fadd v4.4h, v4.4h, v16.4h + fadd v5.4h, v5.4h, v16.4h + fadd v6.4h, v6.4h, v16.4h + fadd v7.4h, v7.4h, v16.4h + + cmp x7, #3 + beq Relu6_8x4 + cmp x7, #1 + beq Relu_8x4 + b Write_8x4 +Relu6_8x4: + fmin v0.4h, v0.4h, v26.4h + fmin v1.4h, v1.4h, v26.4h + fmin v2.4h, v2.4h, v26.4h + fmin v3.4h, v3.4h, v26.4h + fmin v4.4h, v4.4h, v26.4h + fmin v5.4h, v5.4h, v26.4h + fmin v6.4h, v6.4h, v26.4h + fmin v7.4h, v7.4h, v26.4h +Relu_8x4: + fmax v0.4h, v0.4h, v27.4h + fmax v1.4h, v1.4h, v27.4h + fmax v2.4h, v2.4h, v27.4h + fmax v3.4h, v3.4h, v27.4h + fmax v4.4h, v4.4h, v27.4h + fmax v5.4h, v5.4h, v27.4h + fmax v6.4h, v6.4h, v27.4h + fmax v7.4h, v7.4h, v27.4h +Write_8x4: + st1 {v0.4h}, [x15], x12 + st1 {v1.4h}, [x15], x12 + st1 {v2.4h}, [x15], x12 + st1 {v3.4h}, [x15], x12 + st1 {v4.4h}, [x15], x12 + st1 {v5.4h}, [x15], x12 + st1 {v6.4h}, [x15], x12 + st1 {v7.4h}, [x15], x12 + b Loop_8x4 + +Loop_4x4: + cmp w13, #4 + blt Loop_1x4 + sub w13, w13, #4 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x1], #32 + fadd v0.4h, v0.4h, v16.4h + fadd v1.4h, v1.4h, v16.4h + fadd v2.4h, v2.4h, v16.4h + fadd v3.4h, v3.4h, v16.4h + cmp x7, #3 + beq Relu6_4x4 + cmp x7, #1 + beq Relu_4x4
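+ // Editor's note: the relu_type dispatch above and the Relu6_*/Relu_* blocks implement, per output element, the following C logic (a hedged sketch, not upstream code: the values 3 = relu6 and 1 = relu are taken from the cmp immediates in this file, and each Relu6_* block deliberately falls through into the Relu_* block so that relu6 applies both clamps): + // float16_t v = src[i] + bias[c]; + // if (relu_type == 3) v = (v < (float16_t)6.0) ? v : (float16_t)6.0; // upper clamp, relu6 only + // if (relu_type == 1 || relu_type == 3) v = (v > (float16_t)0.0) ? v : (float16_t)0.0; // lower clamp, relu and relu6 + // dst[i] = v;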
+ b Write_4x4 +Relu6_4x4: + fmin v0.4h, v0.4h, v26.4h + fmin v1.4h, v1.4h, v26.4h + fmin v2.4h, v2.4h, v26.4h + fmin v3.4h, v3.4h, v26.4h +Relu_4x4: + fmax v0.4h, v0.4h, v27.4h + fmax v1.4h, v1.4h, v27.4h + fmax v2.4h, v2.4h, v27.4h + fmax v3.4h, v3.4h, v27.4h +Write_4x4: + st1 {v0.4h}, [x15], x12 + st1 {v1.4h}, [x15], x12 + st1 {v2.4h}, [x15], x12 + st1 {v3.4h}, [x15], x12 + +Loop_1x4: + cmp x7, #3 + beq Relu6_1x4 + cmp x7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.4h}, [x15], x12 + b Relu6_1x4 +Relu_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.4h}, [x15], x12 + b Relu_1x4 +Write_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.4h}, [x15], x12 + b Write_1x4 + +HW_Add: + add x1, x1, x6 + b Loop_C4 + +Loop_C1: + cmp w4, #0 + beq End + mov w13, w5 + ld1 {v16.4h}, [x2], #8 + mov x15, #2 + mul x14, x10, x15 + add x0, x0, x14 + + cmp w4, #1 + beq Loop_C1_1 + cmp w4, #2 + beq Loop_C1_2 + cmp w4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Write + +Loop_C1_3: + add x15, x0, #4 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Write + +End: + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC8Fp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC8Fp16.S 
new file mode 100644 index 00000000..c339ac8c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/PostFuncBiasReluC8Fp16.S @@ -0,0 +1,469 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, +// size_t plane_size, size_t stride, int relu_type); +// x0 dst x1 src x2 bias +// x3 oc8div x4 oc8mod x5 plane_size +// x6 stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x22 x23 x24 x25 write loop tmp buf +// v26 relu6 upper bound 6.0; v27 relu lower bound 0 +// w10 oc8 loop control +// w13 hw loop control + +asm_function PostFuncBiasReluC8Fp16 + movi v26.8h, #0x46, lsl #8 + dup v27.8h, wzr + mov w10, #0 + +Loop_C8: + cmp w10, w3 + beq Loop_C1 + mov x25, #2 + mul x24, x10, x25 + add x25, x0, x24 + add w10, w10, #8 + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + +Loop8x8: + cmp w13, #8 + blt Loop_4x8 + sub w13, w13, #8 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + fadd v4.8h, v4.8h, v16.8h + fadd v5.8h, v5.8h, v16.8h + fadd v6.8h, v6.8h, v16.8h + fadd v7.8h, v7.8h, v16.8h + + cmp w7, #2 + beq Relu6_8x8 + cmp w7, #1 + beq Relu_8x8 + b Write_8x8 +Relu6_8x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h +Relu_8x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h +Write_8x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + st1 {v4.8h}, [x25], x6 + st1 {v5.8h}, [x25], x6 + st1 {v6.8h}, [x25], x6 + st1 {v7.8h}, [x25], x6 + b Loop8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + + cmp w7, #2 + beq Relu6_4x8 + cmp w7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h +Relu_4x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h +Write_4x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + b Loop_4x8 + +Loop_1x8: + cmp w7, #2 + beq Relu6_1x8 + cmp w7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h,
v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.8h}, [x25], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + mov x25, #2 + mul x24, x10, x25 + add x22, x0, x24 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp w7, #2 + beq Loop_C1_1_Relu6 + cmp w7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Write + +Loop_C1_2: + add x24, x22, #2 + cmp w7, #2 + beq Loop_C1_2_Relu6 + cmp w7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Write + + +Loop_C1_3: + add x24, x22, #2 + add x25, x22, #4 + cmp w7, #2 + beq Loop_C1_3_Relu6 + cmp w7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp w7, #2 + beq Loop_C1_4_Relu6 + cmp w7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Relu
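+// Editor's note: a hedged C sketch of the oc8mod leftover-channel handling in the Loop_C1_* blocks around this point (variable names are illustrative; src is packed eight channels per plane element, stride is a byte stride as the post-indexed stores suggest, and relu_type 2/1 select relu6/relu in this file): + // for (size_t hw = 0; hw < plane_size; ++hw) { + //   for (size_t c = 0; c < oc8mod; ++c) { + //     float16_t v = src[hw * 8 + c] + bias[c]; // bias already points at the tail block here + //     if (relu_type == 2) v = (v < (float16_t)6.0) ? v : (float16_t)6.0; + //     if (relu_type == 1 || relu_type == 2) v = (v > (float16_t)0.0) ? v : (float16_t)0.0; + //     *(float16_t *)((char *)dst + hw * stride + (oc8div + c) * sizeof(float16_t)) = v; + //   } + // }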
+Loop_C1_4_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Write + +Loop_C1_5: + add x25, x22, #8 + cmp w7, #2 + beq Loop_C1_5_Relu6 + cmp w7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Write + +Loop_C1_6: + add x23, x22, #8 + add x24, x22, #10 + cmp w7, #2 + beq Loop_C1_6_Relu6 + cmp w7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x23, x22, #8 + add x24, x22, #10 + add x25, x22, #12 + cmp w7, #2 + beq Loop_C1_7_Relu6 + cmp w7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Write + +End: + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/TiledC4MatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/TiledC4MatmulFp16.S new file mode 100644 index 00000000..e0f2211e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/TiledC4MatmulFp16.S @@ -0,0 +1,273 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp16 + +sub sp, sp, #128 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] +add x9, sp, #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + +mov x7, #2 //sizeof(float) +mul x3, x3, x7 +mov x7, #32 +mul x10, x4, x7 + +cmp x5, #2 +blt LoopOcHalf +LoopOc: + mov x8, x1 + subs x9, x4, #1 + + add x6, x2, x10 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + fmul v16.4h, v8.4h, v0.h[0] + fmul v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmul v18.4h, v8.4h, v2.h[0] + fmul v19.4h, v8.4h, v3.h[0] + ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32 + fmul v20.4h, v8.4h, v4.h[0] + fmul v21.4h, v8.4h, v5.h[0] + fmul v22.4h, v8.4h, v6.h[0] + fmul v23.4h, v8.4h, v7.h[0] + fmul v24.4h, v12.4h, v0.h[0] + fmul v25.4h, v12.4h, v1.h[0] + fmul v26.4h, v12.4h, v2.h[0] + fmul v27.4h, v12.4h, v3.h[0] + fmul v28.4h, v12.4h, v4.h[0] + fmul v29.4h, v12.4h, v5.h[0] + fmul v30.4h, v12.4h, v6.h[0] + fmul v31.4h, v12.4h, v7.h[0] + + beq LoopIcEnd + LoopIc: + add x2, x2, #64 + prfm pldl1keep, [x2] + prfm pldl1keep, [x2, x10] + sub x2, x2, #64 + prfm pldl1keep, [x8, #64] + prfm pldl1keep, [x8, #96] + + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + fmla v24.4h, v13.4h, v0.h[1] + fmla v25.4h, v13.4h, v1.h[1] + fmla v26.4h, v13.4h, v2.h[1] + fmla v27.4h, v13.4h, v3.h[1] + fmla v28.4h, v13.4h, v4.h[1] + fmla v29.4h, v13.4h, v5.h[1] + fmla v30.4h, v13.4h, v6.h[1] + fmla v31.4h, v13.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + fmla v24.4h, v14.4h, v0.h[2] + fmla v25.4h, v14.4h, v1.h[2] + fmla v26.4h, v14.4h, v2.h[2] + fmla v27.4h, v14.4h, v3.h[2] + fmla v28.4h, v14.4h, v4.h[2] + fmla v29.4h, v14.4h, v5.h[2] + fmla v30.4h, v14.4h, v6.h[2] + fmla v31.4h, v14.4h, v7.h[2] + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + fmla v24.4h, v15.4h, v0.h[3] + fmla v25.4h, v15.4h, v1.h[3] + fmla v26.4h, v15.4h, v2.h[3] + fmla v27.4h, v15.4h, v3.h[3] + fmla v28.4h, v15.4h, v4.h[3] + fmla v29.4h, v15.4h, v5.h[3] + fmla v30.4h, v15.4h, v6.h[3] + fmla v31.4h, v15.4h, v7.h[3] + + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + fmla v16.4h, v8.4h, v0.h[0] + fmla v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmla v18.4h, v8.4h, v2.h[0] + fmla v19.4h, v8.4h, v3.h[0] + ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32 + fmla v20.4h, v8.4h, v4.h[0] + fmla v21.4h, v8.4h, v5.h[0] + fmla v22.4h, 
v8.4h, v6.h[0] + fmla v23.4h, v8.4h, v7.h[0] + fmla v24.4h, v12.4h, v0.h[0] + fmla v25.4h, v12.4h, v1.h[0] + fmla v26.4h, v12.4h, v2.h[0] + fmla v27.4h, v12.4h, v3.h[0] + fmla v28.4h, v12.4h, v4.h[0] + fmla v29.4h, v12.4h, v5.h[0] + fmla v30.4h, v12.4h, v6.h[0] + fmla v31.4h, v12.4h, v7.h[0] + + subs x9, x9, #1 + bne LoopIc + + LoopIcEnd: + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + fmla v24.4h, v13.4h, v0.h[1] + fmla v25.4h, v13.4h, v1.h[1] + fmla v26.4h, v13.4h, v2.h[1] + fmla v27.4h, v13.4h, v3.h[1] + fmla v28.4h, v13.4h, v4.h[1] + fmla v29.4h, v13.4h, v5.h[1] + fmla v30.4h, v13.4h, v6.h[1] + fmla v31.4h, v13.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + fmla v24.4h, v14.4h, v0.h[2] + fmla v25.4h, v14.4h, v1.h[2] + fmla v26.4h, v14.4h, v2.h[2] + fmla v27.4h, v14.4h, v3.h[2] + fmla v28.4h, v14.4h, v4.h[2] + fmla v29.4h, v14.4h, v5.h[2] + fmla v30.4h, v14.4h, v6.h[2] + fmla v31.4h, v14.4h, v7.h[2] + + add x7, x0, #32 + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + fmla v24.4h, v15.4h, v0.h[3] + fmla v25.4h, v15.4h, v1.h[3] + fmla v26.4h, v15.4h, v2.h[3] + fmla v27.4h, v15.4h, v3.h[3] + fmla v28.4h, v15.4h, v4.h[3] + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], x3 + fmla v29.4h, v15.4h, v5.h[3] + st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x7], x3 + fmla v30.4h, v15.4h, v6.h[3] + st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [x0], x3 + mov x2, x6 + fmla v31.4h, v15.4h, v7.h[3] + st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [x7] + + subs x5, x5, #2 + beq LoopOcEnd + cmp x5, #2 + bge LoopOc + +LoopOcHalf: + mov x8, x1 + mov x9, x4 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopIcHalf: + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + fmla v16.4h, v8.4h, v0.h[0] + fmla v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmla v18.4h, v8.4h, v2.h[0] + fmla v19.4h, v8.4h, v3.h[0] + fmla v20.4h, v8.4h, v4.h[0] + fmla v21.4h, v8.4h, v5.h[0] + fmla v22.4h, v8.4h, v6.h[0] + fmla v23.4h, v8.4h, v7.h[0] + + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + + subs x9, x9, #1 + bne LoopIcHalf + + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 + st1 {v20.4h, v21.4h, v22.4h, 
v23.4h}, [x0], #32 + +LoopOcEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S new file mode 100644 index 00000000..bf1803c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/VecMatmulFp16.S @@ -0,0 +1,181 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +// void VecMatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_function VecMatmulFp16Neon64_2 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + +LoopCol: + mov x15, x0 // reload a ptr + ld1 {v0.8h}, [x3], #16 // acc0 + ld1 {v1.8h}, [x3], #16 // acc1 + mov w9, #0 // tmp depth + +Loop2x8Inner: + sub w18, w5, w9 + cmp w18, #8 + blt DepthRemain + + ld1 {v2.8h}, [x15], #16 // a + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x1], #64 + ld1 {v7.8h, v8.8h, v9.8h, v10.8h}, [x1], #64 + ld1 {v11.8h, v12.8h, v13.8h, v14.8h}, [x1], #64 + ld1 {v15.8h, v16.8h, v17.8h, v18.8h}, [x1], #64 + + fmla v0.8h, v3.8h, v2.h[0] + fmla v0.8h, v5.8h, v2.h[1] + fmla v0.8h, v7.8h, v2.h[2] + fmla v0.8h, v9.8h, v2.h[3] + fmla v0.8h, v11.8h, v2.h[4] + fmla v0.8h, v13.8h, v2.h[5] + fmla v0.8h, v15.8h, v2.h[6] + fmla v0.8h, v17.8h, v2.h[7] + fmla v1.8h, v4.8h, v2.h[0] + fmla v1.8h, v6.8h, v2.h[1] + fmla v1.8h, v8.8h, v2.h[2] + fmla v1.8h, v10.8h, v2.h[3] + fmla v1.8h, v12.8h, v2.h[4] + fmla v1.8h, v14.8h, v2.h[5] + fmla v1.8h, v16.8h, v2.h[6] + fmla v1.8h, v18.8h, v2.h[7] + + add w9, w9, #8 + b Loop2x8Inner + +DepthRemain: // last depth [0, 8) + cmp w18, #0 + ble Act + ld1 {v2.h}[0], [x15], #2 + ld1 {v3.8h}, [x1], #16 + ld1 {v4.8h}, [x1], #16 + fmla v0.8h, v3.8h, v2.h[0] + fmla v1.8h, v4.8h, v2.h[0] + sub w18, w18, #1 + b DepthRemain + +Act: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + movi v19.8h, #0x46, lsl #8 + fmin v0.8h, v0.8h, v19.8h + fmin v1.8h, v1.8h, v19.8h + +Relu: + dup v20.8h, wzr + fmax v0.8h, v0.8h, v20.8h + fmax v1.8h, v1.8h, v20.8h + +Write: + cmp w6, #8 + blt WriteMod8 + st1 {v0.8h}, [x2], #16 + sub w6, w6, #8 + mov v0.16b, v1.16b + cmp w6, #8 + blt WriteMod8 + st1 {v1.8h}, [x2], #16 + sub w6, w6, #8 + cbz w6, End + b LoopCol + +WriteMod8: + cmp w6, #0 + ble End + cmp w6, #1 + beq Write1 + cmp w6, #2 + beq Write2 + cmp w6, #3 + beq Write3 + cmp w6, #4 + beq Write4 + cmp w6, #5 + beq Write5 + cmp w6, #6 + beq Write6 + cmp w6, #7 + beq Write7 + +Write1: + st1 {v0.h}[0], [x2], #2 + b End +Write2: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + b End +Write3: + st1 {v0.h}[0], [x2], #2 + st1 
{v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + b End +Write4: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + b End +Write5: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + b End +Write6: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + st1 {v0.h}[5], [x2], #2 + b End +Write7: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + st1 {v0.h}[5], [x2], #2 + st1 {v0.h}[6], [x2], #2 + b End + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransLeftFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransLeftFp16.S new file mode 100644 index 00000000..c308eb34 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransLeftFp16.S @@ -0,0 +1,150 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransLeftFp16 + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #8 // 4 * sizeof(float16) +mul x8, x6, x8 +mul x9, x3, x8 +sub x9, x9, x8 +add x7, x9, x8 // step for S +mov x10, #2 +mul x10, x4, x10 // step for B + +LoopH: + mov x13, x0 + mov x15, x3 + LoopW: + mov x14, x13 + mov x17, x1 + dup v30.4h, wzr + mov x11, x6 + InitZero: + st1 {v30.4h}, [x2], #8 + subs x11, x11, #1 + bne InitZero + + sub x2, x2, x8 + mov x12, x5 + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.h}[0], [x17], x10 + ld1 {v0.h}[1], [x17], x10 + ld1 {v0.h}[2], [x17], x10 + ld1 {v0.h}[3], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + add x19, x16, x7 + + LoopLength4: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x14], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x20], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + ld1 {v21.4h}, [x19], #8 + fmla v17.4h, v21.4h, v0.h[3] + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength4 + + sub x2, x2, x8 + sub x12, x12, #4 + add x14, x19, x9 + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.h}[0], [x17], x10 + ld1 {v0.h}[1], [x17], x10 + ld1 {v0.h}[2], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + LoopLength3: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x14], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x20], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength3 + + sub x2, x2, x8 + sub x12, x12, #3 + add x14, x16, x9 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LKEnd + LoopK: + ld1r {v31.4h}, [x17], x10 + mov x11, x6 + LoopLength: + ld1 {v0.4h}, [x2] + ld1 {v1.4h}, [x14], #8 + fmla v0.4h, v1.4h, v31.4h + st1 {v0.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength + + subs x12, x12, #1 + sub x2, x2, x8 + add x14, x14, x9 + bne LoopK + + LKEnd: + subs x15, x15, #1 + add x13, x13, x8 + add x2, x2, x8 + bne LoopW + + add x1, x1, #2 //sizeof(float) + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransRightFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransRightFp16.S new file mode 100644 index 00000000..cde99cc1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/fp16/WinogradTransRightFp16.S @@ -0,0 +1,154 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransRightFp16 + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #8 // 4 * sizeof(float16) +mul x8, x6, x8 +mul x9, x5, x8 // step for S +mov x10, #2 +mul x10, x4, x10 // step for B + +LoopH: + mov x7, x1 + mov x15, x3 + LoopW: + mov x17, x0 + mov x13, x7 + dup v30.4h, wzr + mov x11, x6 + InitZero: + st1 {v30.4h}, [x2], #8 + subs x11, x11, #1 + bne InitZero + sub x2, x2, x8 + mov x12, x5 + + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.h}[0], [x13], x10 + ld1 {v0.h}[1], [x13], x10 + ld1 {v0.h}[2], [x13], x10 + ld1 {v0.h}[3], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + add x19, x16, x8 + + LoopLength4: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x17], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x14], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + ld1 {v21.4h}, [x19], #8 + fmla v17.4h, v21.4h, v0.h[3] + + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength4 + sub x2, x2, x8 + sub x12, x12, #4 + mov x17, x19 + + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.h}[0], [x13], x10 + ld1 {v0.h}[1], [x13], x10 + ld1 {v0.h}[2], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + + LoopLength3: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x17], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x14], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength3 + sub x2, x2, x8 + sub x12, x12, #3 + mov x17, x19 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LoopKEnd + + LoopK: + ld1r {v31.4h}, [x13], x10 + + mov x11, x6 + LoopLength: + ld1 {v0.4h}, [x2] + ld1 {v1.4h}, [x17], #8 + fmla v0.4h, v1.4h, v31.4h + + st1 {v0.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength + subs x12, x12, #1 + + sub x2, x2, x8 + bne LoopK + LoopKEnd: + subs x15, x15, #1 + add x2, x2, x8 + add x7, x7, #2 + bne LoopW + + add x0, x0, x9 + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + + ret + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S new file mode 100644 index 00000000..106eba38 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S @@ -0,0 +1,764 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +// void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales, +// float *bias, size_t row, size_t col, size_t stride, const int *a_sums, +// const int *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: deep +// x4: multi_scales +// x5: bias +// x6: row +// x7: col +// x8: stride +// x9: a_sums +// x10: b_sums +// x19/w19: a_zp +// x20/w20: b_zp_sum +// x21: act_type -> 0: none, 1:Relu, 3:Relu6 +// x22: mode -> 0: TensorXTensor, 1: TensorXChannel, 2: ChannelXTensor, else: ChannelXChannel + +asm_function DynamicMatmulSdot4x4x16AIWI + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x19, [sp, #24] + ldr x20, [sp, #32] + ldr x21, [sp, #40] + ldr x22, [sp, #48] + + dup v16.4s, wzr // dup:Duplicate general-purpose register to vector. + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + mov x11, x1 // reload rhs ptr + mov x17, x0 // reload lhs ptr + mov x16, x3 // reload depth + + cmp x7, #4 + ble LoopDepthQuarter + cmp x7, #8 + ble LoopDepthHalf + +LoopDepth: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64 + + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepth + b AddInputSum + +LoopDepthHalf: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthHalf + b AddInputSum + +LoopDepthQuarter: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthQuarter + b AddInputSum + +AddInputSum: + cmp w20, #0 + beq AddInputSumEnd + ld1 {v5.4s}, [x9], #16 + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + dup v8.4s, v5.s[2] + dup v9.4s, v5.s[3] + + sub v16.4s, v16.4s, v6.4s + sub v17.4s, v17.4s, v6.4s + sub v18.4s, v18.4s, v6.4s + sub v19.4s, v19.4s, v6.4s + sub v20.4s, v20.4s, v7.4s + sub v21.4s, v21.4s, v7.4s + sub v22.4s, v22.4s, v7.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v8.4s + sub v26.4s, v26.4s, v8.4s + sub v27.4s, v27.4s, v8.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v9.4s + sub v30.4s, v30.4s, v9.4s + sub v31.4s, v31.4s, v9.4s +AddInputSumEnd: + +AddWeightSum: + ld1 {v9.4s}, [x10], #16 + ld1 {v10.4s}, [x10], #16 + ld1
{v11.4s}, [x10], #16 + ld1 {v12.4s}, [x10], #16 + dup v13.4s, w19 + mul v9.4s, v9.4s, v13.4s + mul v10.4s, v10.4s, v13.4s + mul v11.4s, v11.4s, v13.4s + mul v12.4s, v12.4s, v13.4s + sub v16.4s, v16.4s, v9.4s + sub v17.4s, v17.4s, v10.4s + sub v18.4s, v18.4s, v11.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v9.4s + sub v21.4s, v21.4s, v10.4s + sub v22.4s, v22.4s, v11.4s + sub v23.4s, v23.4s, v12.4s + sub v24.4s, v24.4s, v9.4s + sub v25.4s, v25.4s, v10.4s + sub v26.4s, v26.4s, v11.4s + sub v27.4s, v27.4s, v12.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v10.4s + sub v30.4s, v30.4s, v11.4s + sub v31.4s, v31.4s, v12.4s + +AddZpSum: + mul w15, w19, w20 + cmp w15, #0 + beq AddZpSumEnd + dup v14.4s, w15 + add v16.4s, v16.4s, v14.4s + add v17.4s, v17.4s, v14.4s + add v18.4s, v18.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v20.4s, v20.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v22.4s, v22.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v24.4s, v24.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v26.4s, v26.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v28.4s, v28.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v30.4s, v30.4s, v14.4s + add v31.4s, v31.4s, v14.4s +AddZpSumEnd: + +Convert2Float: + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + scvtf v31.4s, v31.4s + +MultiplyScale: + // multi_scale * input_matrix + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b AddBias + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b AddBias + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b AddBias + + ChannelXTensor: + ld1 {v0.4s}, [x4] + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul 
v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] +AddBias: + // +bias + cbz x5, StoreData + ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x5] + + fadd v16.4s,v16.4s,v1.4s + fadd v17.4s,v17.4s,v2.4s + fadd v18.4s,v18.4s,v3.4s + fadd v19.4s,v19.4s,v4.4s + + fadd v20.4s,v20.4s,v1.4s + fadd v21.4s,v21.4s,v2.4s + fadd v22.4s,v22.4s,v3.4s + fadd v23.4s,v23.4s,v4.4s + + fadd v24.4s,v24.4s,v1.4s + fadd v25.4s,v25.4s,v2.4s + fadd v26.4s,v26.4s,v3.4s + fadd v27.4s,v27.4s,v4.4s + + fadd v28.4s,v28.4s,v1.4s + fadd v29.4s,v29.4s,v2.4s + fadd v30.4s,v30.4s,v3.4s + fadd v31.4s,v31.4s,v4.4s + +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4s, wzr + + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + b StoreData + +Relu6: + dup v1.4s, wzr + movi v2.4s, #6 + scvtf v2.4s, v2.4s + + // max (out, 0) + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + // min (out, 6) + + smin v16.4s,v16.4s,v2.4s + smin v17.4s,v17.4s,v2.4s + smin v18.4s,v18.4s,v2.4s + smin v19.4s,v19.4s,v2.4s + + smin v20.4s,v20.4s,v2.4s + smin v21.4s,v21.4s,v2.4s + smin v22.4s,v22.4s,v2.4s + smin v23.4s,v23.4s,v2.4s + + smin v24.4s,v24.4s,v2.4s + smin v25.4s,v25.4s,v2.4s + smin v26.4s,v26.4s,v2.4s + smin v27.4s,v27.4s,v2.4s + + smin v28.4s,v28.4s,v2.4s + smin v29.4s,v29.4s,v2.4s + smin v30.4s,v30.4s,v2.4s + smin v31.4s,v31.4s,v2.4s + + b StoreData + +StoreData: + cmp x7, #16 + beq Write16 + + mov x15, x2 // reload out ptr + add x14, x15, x8 + add x13, x14, x8 + add x12, x13, x8 + + cmp x7, #15 + beq Write15 + cmp x7, #14 + beq Write14 + cmp x7, #13 + beq Write13 + cmp x7, #12 + beq Write12 + cmp x7, #11 + beq Write11 + cmp x7, #10 + beq Write10 + cmp x7, #9 + beq Write9 + cmp x7, #8 + beq Write8 + cmp x7, #7 + beq Write7 + cmp x7, #6 + beq Write6 + cmp x7, #5 + beq Write5 + cmp x7, #4 + beq Write4 + cmp x7, #3 + beq Write3 + cmp x7, #2 + beq Write2 + cmp x7, #1 + beq Write1 + b StoreDataEnd + +Write16: + cmp x6, #4 + beq Write16Row4 + cmp x6, #3 + beq Write16Row3 + cmp x6, #2 + beq Write16Row2 + cmp x6, #1 + beq Write16Row1 + + Write16Row4: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8 + st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2], x8 + st1 {v28.4s,v29.4s,v30.4s,v31.4s}, [x2] + b StoreDataEnd + Write16Row3: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8 + st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2] + b 
StoreDataEnd + Write16Row2: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2] + b StoreDataEnd + Write16Row1: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2] + b StoreDataEnd + +Write15: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.1d}, [x15], #8 + st1 {v19.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.1d}, [x14], #8 + st1 {v23.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.1d}, [x13], #8 + st1 {v27.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.1d}, [x12], #8 + st1 {v31.s}[2], [x12] + b StoreDataEnd + +Write14: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.1d}, [x12] + b StoreDataEnd + +Write13: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.s}[0], [x12] + b StoreDataEnd + +Write12: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + b StoreDataEnd + +Write11: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.1d}, [x15], #8 + st1 {v18.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.1d}, [x14], #8 + st1 {v22.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.1d}, [x13], #8 + st1 {v26.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.1d}, [x12], #8 + st1 {v30.s}[2], [x12] + b StoreDataEnd + +Write10: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.1d}, [x12] + b StoreDataEnd + +Write9: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.s}[0], [x12] + b StoreDataEnd + +Write8: + st1 {v16.4s,v17.4s}, [x15], #32 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + b StoreDataEnd + +Write7: + st1 {v16.4s}, [x15], #16 + st1 {v17.1d}, [x15], #8 + st1 {v17.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.1d}, [x14], #8 + st1 {v21.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.1d}, [x13], #8 + st1 {v25.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.1d}, 
[x12], #8 + st1 {v29.s}[2], [x12] + b StoreDataEnd + +Write6: + st1 {v16.4s}, [x15], #16 + st1 {v17.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.1d}, [x12] + b StoreDataEnd + +Write5: + st1 {v16.4s}, [x15], #16 + st1 {v17.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.s}[0], [x12] + b StoreDataEnd + +Write4: + st1 {v16.4s}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12] + b StoreDataEnd + +Write3: + st1 {v16.1d}, [x15], #8 + st1 {v16.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.1d}, [x14], #8 + st1 {v20.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.1d}, [x13], #8 + st1 {v24.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.1d}, [x12], #8 + st1 {v28.s}[2], [x12] + b StoreDataEnd + +Write2: + st1 {v16.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.1d}, [x12] + b StoreDataEnd + +Write1: + st1 {v16.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12] + b StoreDataEnd +StoreDataEnd: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S new file mode 100644 index 00000000..b60055bd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S @@ -0,0 +1,788 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl_c/assembly_global.h"
+.text
+.align 5
+
+// void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4,
+//                                         float16_t *multi_scales, float16_t *bias, size_t row, size_t col, size_t stride,
+//                                         const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum,
+//                                         int64_t act_type, int64_t mode);
+// x0: a(left matrix ptr)
+// x1: b(right matrix ptr)
+// x2: out ptr
+// x3: deep
+// x4: multi_scales
+// x5: bias
+// x6: row
+// x7: col
+// x8: stride
+// x9: a_sums
+// x10: b_sums
+// x19/w19: a_zp
+// x20/w20: b_zp_sum
+// x21: act_type -> 0: none, 1: Relu, 3: Relu6
+// x22: mode -> 0: TensorXTensor, 1: TensorXChannel, 2: ChannelXTensor, else: ChannelXChannel
+
+asm_function DynamicMatmulSdot4x4x16AIWIForFp16
+    sub sp, sp, #160
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+
+    ldr x8, [sp]
+    ldr x9, [sp, #8]
+    ldr x10, [sp, #16]
+    ldr x19, [sp, #24]
+    ldr x20, [sp, #32]
+    ldr x21, [sp, #40]
+    ldr x22, [sp, #48]
+
+    dup v16.4s, wzr // dup: duplicate a general-purpose register into every vector lane
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+
+    mov x11, x1 // reload rhs ptr
+    mov x17, x0 // reload lhs ptr
+    mov x16, x3 // reload depth
+
+    cmp x7, #4
+    ble LoopDepthQuarter
+    cmp x7, #8
+    ble LoopDepthHalf
+
+LoopDepth:
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64
+
+    sdot v16.4s, v1.16b, v0.4b[0]
+    sdot v17.4s, v2.16b, v0.4b[0]
+    sdot v18.4s, v3.16b, v0.4b[0]
+    sdot v19.4s, v4.16b, v0.4b[0]
+    sdot v20.4s, v1.16b, v0.4b[1]
+    sdot v21.4s, v2.16b, v0.4b[1]
+    sdot v22.4s, v3.16b, v0.4b[1]
+    sdot v23.4s, v4.16b, v0.4b[1]
+    sdot v24.4s, v1.16b, v0.4b[2]
+    sdot v25.4s, v2.16b, v0.4b[2]
+    sdot v26.4s, v3.16b, v0.4b[2]
+    sdot v27.4s, v4.16b, v0.4b[2]
+    sdot v28.4s, v1.16b, v0.4b[3]
+    sdot v29.4s, v2.16b, v0.4b[3]
+    sdot v30.4s, v3.16b, v0.4b[3]
+    sdot v31.4s, v4.16b, v0.4b[3]
+
+    subs x16, x16, #4
+    bgt LoopDepth
+    b AddInputSum
+
+LoopDepthHalf:
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b, v2.16b}, [x11]
+    add x11, x11, #64
+    sdot v16.4s, v1.16b, v0.4b[0]
+    sdot v17.4s, v2.16b, v0.4b[0]
+    sdot v20.4s, v1.16b, v0.4b[1]
+    sdot v21.4s, v2.16b, v0.4b[1]
+    sdot v24.4s, v1.16b, v0.4b[2]
+    sdot v25.4s, v2.16b, v0.4b[2]
+    sdot v28.4s, v1.16b, v0.4b[3]
+    sdot v29.4s, v2.16b, v0.4b[3]
+
+    subs x16, x16, #4
+    bgt LoopDepthHalf
+    b AddInputSum
+
+LoopDepthQuarter:
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b}, [x11]
+    add x11, x11, #64
+    sdot v16.4s, v1.16b, v0.4b[0]
+    sdot v20.4s, v1.16b, v0.4b[1]
+    sdot v24.4s, v1.16b, v0.4b[2]
+    sdot v28.4s, v1.16b, v0.4b[3]
+
+    subs x16, x16, #4
+    bgt LoopDepthQuarter
+    b AddInputSum
+
+AddInputSum:
+    // acc[r][*] -= a_sums[r]: compensation for the rhs zero point; skipped when b_zp_sum == 0
+    cmp w20, #0
+    beq AddInputSumEnd
+    ld1 {v5.4s}, [x9], #16
+    dup v6.4s, v5.s[0]
+    dup v7.4s, v5.s[1]
+    dup v8.4s, v5.s[2]
+    dup v9.4s, v5.s[3]
+
+    sub v16.4s, v16.4s, v6.4s
+    sub v17.4s, v17.4s, v6.4s
+    sub v18.4s, v18.4s, v6.4s
+    sub v19.4s, v19.4s, v6.4s
+    sub v20.4s, v20.4s, v7.4s
+    sub v21.4s, v21.4s, v7.4s
+    sub v22.4s, v22.4s, v7.4s
+    sub v23.4s, v23.4s, v7.4s
+    sub v24.4s, v24.4s, v8.4s
+    sub v25.4s, v25.4s, v8.4s
+    sub v26.4s, v26.4s, v8.4s
+    sub v27.4s, v27.4s, v8.4s
+    sub v28.4s, v28.4s, v9.4s
+    sub v29.4s, v29.4s, v9.4s
+    sub v30.4s, v30.4s, v9.4s
+    sub v31.4s, v31.4s, v9.4s
+AddInputSumEnd:
+
+AddWeightSum:
+    // acc[*][c] -= a_zp * b_sums[c] for each of the 16 output columns
+    ld1 {v9.4s}, [x10],
#16 + ld1 {v10.4s}, [x10], #16 + ld1 {v11.4s}, [x10], #16 + ld1 {v12.4s}, [x10], #16 + dup v13.4s, w19 + mul v9.4s, v9.4s, v13.4s + mul v10.4s, v10.4s, v13.4s + mul v11.4s, v11.4s, v13.4s + mul v12.4s, v12.4s, v13.4s + sub v16.4s, v16.4s, v9.4s + sub v17.4s, v17.4s, v10.4s + sub v18.4s, v18.4s, v11.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v9.4s + sub v21.4s, v21.4s, v10.4s + sub v22.4s, v22.4s, v11.4s + sub v23.4s, v23.4s, v12.4s + sub v24.4s, v24.4s, v9.4s + sub v25.4s, v25.4s, v10.4s + sub v26.4s, v26.4s, v11.4s + sub v27.4s, v27.4s, v12.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v10.4s + sub v30.4s, v30.4s, v11.4s + sub v31.4s, v31.4s, v12.4s + +AddZpSum: + mul w15, w19, w20 + cmp w15, #0 + beq AddZpSumEnd + dup v14.4s, w15 + add v16.4s, v16.4s, v14.4s + add v17.4s, v17.4s, v14.4s + add v18.4s, v18.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v20.4s, v20.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v22.4s, v22.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v24.4s, v24.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v26.4s, v26.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v28.4s, v28.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v30.4s, v30.4s, v14.4s + add v31.4s, v31.4s, v14.4s +AddZpSumEnd: + +Convert2Float: + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + scvtf v31.4s, v31.4s + +MultiplyScale: + // multi_scale * input_matrix + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b ConvertHalfPrecision + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b ConvertHalfPrecision + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b ConvertHalfPrecision + + ChannelXTensor: + ld1 {v0.4s}, [x4] + 
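+    // The four multi_scales layouts dispatched on x22 above, as a minimal C
+    // reference (an illustrative sketch only, not part of this patch: the
+    // helper name ApplyMultiScaleRef and the flattened 4x16 accumulator view
+    // are assumptions; note the vector loads in this section read the scales
+    // as 32-bit floats):
+    //
+    //   static void ApplyMultiScaleRef(float acc[4][16], const float *scales, int mode) {
+    //     for (int r = 0; r < 4; ++r) {
+    //       for (int c = 0; c < 16; ++c) {
+    //         float s;
+    //         if (mode == 0) {         // TensorXTensor: one scalar for the whole tile
+    //           s = scales[0];
+    //         } else if (mode == 1) {  // TensorXChannel: one scale per output column
+    //           s = scales[c];
+    //         } else if (mode == 2) {  // ChannelXTensor: one scale per output row
+    //           s = scales[r];
+    //         } else {                 // ChannelXChannel: a full 4x16 scale tile
+    //           s = scales[r * 16 + c];
+    //         }
+    //         acc[r][c] *= s;
+    //       }
+    //     }
+    //   }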
fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] + +ConvertHalfPrecision: +// from single-precision convert to half-precision + fcvtn v16.4h,v16.4s + fcvtn v17.4h,v17.4s + fcvtn v18.4h,v18.4s + fcvtn v19.4h,v19.4s + + fcvtn v20.4h,v20.4s + fcvtn v21.4h,v21.4s + fcvtn v22.4h,v22.4s + fcvtn v23.4h,v23.4s + + fcvtn v24.4h,v24.4s + fcvtn v25.4h,v25.4s + fcvtn v26.4h,v26.4s + fcvtn v27.4h,v27.4s + + fcvtn v28.4h,v28.4s + fcvtn v29.4h,v29.4s + fcvtn v30.4h,v30.4s + fcvtn v31.4h,v31.4s + +AddBias: + // +bias + cbz x5, StoreData + ld1 {v1.4h, v2.4h, v3.4h, v4.4h}, [x5] + + fadd v16.4h,v16.4h,v1.4h + fadd v17.4h,v17.4h,v2.4h + fadd v18.4h,v18.4h,v3.4h + fadd v19.4h,v19.4h,v4.4h + + fadd v20.4h,v20.4h,v1.4h + fadd v21.4h,v21.4h,v2.4h + fadd v22.4h,v22.4h,v3.4h + fadd v23.4h,v23.4h,v4.4h + + fadd v24.4h,v24.4h,v1.4h + fadd v25.4h,v25.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v27.4h,v27.4h,v4.4h + + fadd v28.4h,v28.4h,v1.4h + fadd v29.4h,v29.4h,v2.4h + fadd v30.4h,v30.4h,v3.4h + fadd v31.4h,v31.4h,v4.4h + +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4h, wzr + + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + b StoreData + +Relu6: + dup v1.4h, wzr + movi v2.4h, #6 + scvtf v2.4h, v2.4h + + // max (out, 0) + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + // min (out, 6) + + smin v16.4h,v16.4h,v2.4h + smin v17.4h,v17.4h,v2.4h + smin v18.4h,v18.4h,v2.4h + smin v19.4h,v19.4h,v2.4h + + smin v20.4h,v20.4h,v2.4h + smin v21.4h,v21.4h,v2.4h + smin v22.4h,v22.4h,v2.4h + smin v23.4h,v23.4h,v2.4h + + smin v24.4h,v24.4h,v2.4h + smin v25.4h,v25.4h,v2.4h + smin v26.4h,v26.4h,v2.4h + smin v27.4h,v27.4h,v2.4h + + smin v28.4h,v28.4h,v2.4h + smin v29.4h,v29.4h,v2.4h + smin v30.4h,v30.4h,v2.4h + smin v31.4h,v31.4h,v2.4h + + b StoreData + +StoreData: + cmp x7, #16 + beq Write16 + + mov x15, x2 // reload out ptr + add x14, x15, x8 + add x13, x14, x8 + add x12, x13, x8 + + cmp x7, #15 + beq Write15 + cmp x7, #14 + beq Write14 + cmp x7, #13 + beq Write13 + cmp x7, #12 + beq Write12 + cmp x7, #11 + beq Write11 + cmp x7, #10 + beq Write10 + cmp x7, #9 + beq Write9 + cmp x7, #8 + beq Write8 + cmp x7, #7 + beq Write7 + cmp x7, #6 + beq Write6 + cmp x7, #5 + beq Write5 + cmp x7, #4 + beq Write4 + cmp x7, #3 + beq Write3 + cmp x7, #2 + beq Write2 + cmp x7, #1 + beq Write1 + b StoreDataEnd + +Write16: + 
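+    // Write16 streams full rows through x2 with the stride in x8; every
+    // narrower column count below decomposes each row into the widest
+    // available stores (.4h pairs, then a .s lane, then an .h lane) and bails
+    // out via the row count in x6. The same tail handling as a hedged C
+    // sketch (illustrative only: the float16 typedef and the StoreTileRef
+    // name are assumptions, and memcpy stands in for the split stores):
+    //
+    //   #include <string.h>
+    //   typedef _Float16 float16;
+    //   static void StoreTileRef(float16 *dst, const float16 src[4][16],
+    //                            int row, int col, size_t stride_bytes) {
+    //     for (int r = 0; r < row && r < 4; ++r) {
+    //       memcpy((char *)dst + (size_t)r * stride_bytes, src[r], (size_t)col * sizeof(float16));
+    //     }
+    //   }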
cmp x6, #4 + beq Write16Row4 + cmp x6, #3 + beq Write16Row3 + cmp x6, #2 + beq Write16Row2 + cmp x6, #1 + beq Write16Row1 + + Write16Row4: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2], x8 + st1 {v28.4h,v29.4h,v30.4h,v31.4h}, [x2] + b StoreDataEnd + Write16Row3: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2] + b StoreDataEnd + Write16Row2: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2] + b StoreDataEnd + Write16Row1: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2] + b StoreDataEnd + +Write15: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15], #4 + st1 {v19.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14], #4 + st1 {v23.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13], #4 + st1 {v27.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12], #4 + st1 {v31.h}[2], [x12] + b StoreDataEnd + +Write14: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12] + b StoreDataEnd + +Write13: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.h}[0], [x12] + b StoreDataEnd + +Write12: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + b StoreDataEnd + +Write11: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15], #4 + st1 {v18.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14], #4 + st1 {v22.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13], #4 + st1 {v26.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.s}[0], [x12], #4 + st1 {v30.h}[2], [x12] + b StoreDataEnd + +Write10: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.s}[0], [x12] + b StoreDataEnd + +Write9: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.h}[0], [x12] + b StoreDataEnd + +Write8: + st1 {v16.4h,v17.4h}, [x15], #16 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + cmp x6, 
#2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + b StoreDataEnd + +Write7: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15], #4 + st1 {v17.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14], #4 + st1 {v21.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13], #4 + st1 {v25.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12], #4 + st1 {v29.h}[2], [x12] + b StoreDataEnd + +Write6: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12] + b StoreDataEnd + +Write5: + st1 {v16.4h}, [x15], #8 + st1 {v17.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.h}[0], [x12] + b StoreDataEnd + +Write4: + st1 {v16.4h}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12] + b StoreDataEnd + +Write3: + st1 {v16.s}[0], [x15], #4 + st1 {v16.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14], #4 + st1 {v20.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13], #4 + st1 {v24.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12], #4 + st1 {v28.h}[2], [x12] + b StoreDataEnd + +Write2: + st1 {v16.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12] + b StoreDataEnd + +Write1: + st1 {v16.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.h}[0], [x12] + b StoreDataEnd +StoreDataEnd: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8.S new file mode 100644 index 00000000..c2818043 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8.S @@ -0,0 +1,864 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +//void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, +// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, +// const int *multiplier, const int *left_shift, const int *right_shift, int row, +// int col, int stride, int peroc); + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row8 +// w4: col8 +// w5: deep4 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// w14: row +// w15: col +// w24: stride +// w27: filter_peroc + +asm_function MatmulInt8DpNeon64 + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr w14, [sp, #48] + ldr w15, [sp, #56] + ldr w24, [sp, #64] + ldr w27, [sp, #72] + + mov w17, #8 // sizeof(int8)*8 + mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4 + mov x25, x2 +L1: + cmp w4, #0 // if at the end of col8 + beq End1 + + mov w16, w3 // reset a row8 counter + mov w23, w14 // reset a row counter + mov x17, x0 // reload a ptr + mov x22, x6 // reload a_sums ptr +L2: + cmp w16, #0 + beq End2 + + mov x28, x1 // reload b ptr + mov x19, x7 // reload bias ptr + mov w20, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w20, #16 + blt LoopD4 + +LoopD16: + ld1 {v0.16b, v1.16b}, [x17], #32 + ld1 {v2.16b, v3.16b}, [x28], #32 + + sdot v16.4s, v2.16b, v0.4b[0] + sdot v18.4s, v2.16b, v0.4b[1] + sdot v20.4s, v2.16b, v0.4b[2] + sdot v22.4s, v2.16b, v0.4b[3] + + ld1 {v4.16b, v5.16b}, [x17], #32 + sdot v24.4s, v2.16b, v1.4b[0] + sdot v26.4s, v2.16b, v1.4b[1] + sdot v28.4s, v2.16b, v1.4b[2] + sdot v30.4s, v2.16b, v1.4b[3] + + ld1 {v6.16b, v7.16b}, [x28], #32 + sdot v17.4s, v3.16b, v0.4b[0] + sdot v19.4s, v3.16b, v0.4b[1] + sdot v21.4s, v3.16b, v0.4b[2] + sdot v23.4s, v3.16b, v0.4b[3] + + sdot v25.4s, v3.16b, v1.4b[0] + sdot v27.4s, v3.16b, v1.4b[1] + sdot v29.4s, v3.16b, v1.4b[2] + sdot v31.4s, v3.16b, v1.4b[3] + + ld1 {v8.16b, v9.16b}, [x17], #32 + sdot v16.4s, v6.16b, v4.4b[0] + sdot v18.4s, v6.16b, v4.4b[1] + sdot v20.4s, v6.16b, v4.4b[2] + sdot v22.4s, v6.16b, v4.4b[3] + + sdot v24.4s, v6.16b, v5.4b[0] + sdot v26.4s, v6.16b, v5.4b[1] + sdot v28.4s, v6.16b, v5.4b[2] + sdot v30.4s, v6.16b, v5.4b[3] + + ld1 {v10.16b, v11.16b}, [x28], #32 + sdot v17.4s, v7.16b, v4.4b[0] + sdot v19.4s, v7.16b, v4.4b[1] + sdot v21.4s, v7.16b, v4.4b[2] + sdot v23.4s, v7.16b, v4.4b[3] + + sdot v25.4s, v7.16b, v5.4b[0] + sdot v27.4s, v7.16b, v5.4b[1] + sdot v29.4s, v7.16b, v5.4b[2] + sdot v31.4s, v7.16b, v5.4b[3] + + ld1 {v12.16b, v13.16b}, [x17], #32 + sdot v16.4s, v10.16b, v8.4b[0] + sdot v18.4s, v10.16b, v8.4b[1] + sdot v20.4s, v10.16b, v8.4b[2] + sdot v22.4s, v10.16b, v8.4b[3] + + sdot v24.4s, v10.16b, v9.4b[0] + sdot v26.4s, v10.16b, v9.4b[1] + sdot v28.4s, v10.16b, v9.4b[2] + sdot v30.4s, v10.16b, v9.4b[3] + + ld1 {v14.16b, v15.16b}, [x28], #32 + sdot 
v17.4s, v11.16b, v8.4b[0] + sdot v19.4s, v11.16b, v8.4b[1] + sdot v21.4s, v11.16b, v8.4b[2] + sdot v23.4s, v11.16b, v8.4b[3] + + sdot v25.4s, v11.16b, v9.4b[0] + sdot v27.4s, v11.16b, v9.4b[1] + sdot v29.4s, v11.16b, v9.4b[2] + sdot v31.4s, v11.16b, v9.4b[3] + + sdot v16.4s, v14.16b, v12.4b[0] + sdot v18.4s, v14.16b, v12.4b[1] + sdot v20.4s, v14.16b, v12.4b[2] + sdot v22.4s, v14.16b, v12.4b[3] + + sdot v24.4s, v14.16b, v13.4b[0] + sdot v26.4s, v14.16b, v13.4b[1] + sdot v28.4s, v14.16b, v13.4b[2] + sdot v30.4s, v14.16b, v13.4b[3] + + sdot v17.4s, v15.16b, v12.4b[0] + sdot v19.4s, v15.16b, v12.4b[1] + sdot v21.4s, v15.16b, v12.4b[2] + sdot v23.4s, v15.16b, v12.4b[3] + + sdot v25.4s, v15.16b, v13.4b[0] + sdot v27.4s, v15.16b, v13.4b[1] + sdot v29.4s, v15.16b, v13.4b[2] + sdot v31.4s, v15.16b, v13.4b[3] + + subs w20, w20, #16 // depth - 16 + b L3 + +LoopD4: + cmp w20, #0 + beq End3 + + ld1 {v0.16b, v1.16b}, [x17], #32 + ld1 {v2.16b, v3.16b}, [x28], #32 + + sdot v16.4s, v2.16b, v0.4b[0] + sdot v18.4s, v2.16b, v0.4b[1] + sdot v20.4s, v2.16b, v0.4b[2] + sdot v22.4s, v2.16b, v0.4b[3] + sdot v24.4s, v2.16b, v1.4b[0] + sdot v26.4s, v2.16b, v1.4b[1] + sdot v28.4s, v2.16b, v1.4b[2] + sdot v30.4s, v2.16b, v1.4b[3] + sdot v17.4s, v3.16b, v0.4b[0] + sdot v19.4s, v3.16b, v0.4b[1] + sdot v21.4s, v3.16b, v0.4b[2] + sdot v23.4s, v3.16b, v0.4b[3] + sdot v25.4s, v3.16b, v1.4b[0] + sdot v27.4s, v3.16b, v1.4b[1] + sdot v29.4s, v3.16b, v1.4b[2] + sdot v31.4s, v3.16b, v1.4b[3] + + subs w20, w20, #4 // depth - 4 + b LoopD4 + +End3: + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x19], #16 + ld1 {v14.4s}, [x19], #16 + add v16.4s, v16.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v20.4s, v20.4s, v15.4s + add v22.4s, v22.4s, v15.4s + add v24.4s, v24.4s, v15.4s + add v26.4s, v26.4s, v15.4s + add v28.4s, v28.4s, v15.4s + add v30.4s, v30.4s, v15.4s + add v17.4s, v17.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v31.4s, v31.4s, v14.4s + + cmp w27, #0 + beq PerTSumLoad +PerCSumLoad: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + b ApplySum +PerTSumLoad: + ld1 {v14.4s}, [x22], #16 + ld1 {v15.4s}, [x22], #16 + dup v0.4s, v14.s[0] + dup v1.4s, v14.s[0] + dup v2.4s, v14.s[1] + dup v3.4s, v14.s[1] + dup v4.4s, v14.s[2] + dup v5.4s, v14.s[2] + dup v6.4s, v14.s[3] + dup v7.4s, v14.s[3] + dup v8.4s, v15.s[0] + dup v9.4s, v15.s[0] + dup v10.4s, v15.s[1] + dup v11.4s, v15.s[1] + dup v12.4s, v15.s[2] + dup v13.4s, v15.s[2] + dup v14.4s, v15.s[3] + dup v15.4s, v14.s[0] +ApplySum: + // Subtract (Asums*Zb) + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v9.4s + sub v26.4s, v26.4s, v10.4s + sub v27.4s, v27.4s, v11.4s + sub v28.4s, v28.4s, v12.4s + sub v29.4s, v29.4s, v13.4s + sub v30.4s, v30.4s, v14.4s + sub v31.4s, v31.4s, v15.4s + + cmp w27, #0 + beq PerTRoundLoad +PerCRoundLoad: + ld1 {v8.4s, v9.4s}, [x12] + ld1 {v10.4s, v11.4s}, [x11] + ld1 {v12.4s, v13.4s}, [x13] + b ApplyRound +PerTRoundLoad: + ld1 {v14.s}[0], [x12] + dup v8.4s, v14.s[0] + dup v9.4s, v14.s[0] + ld1 {v14.s}[0], [x11] + dup v10.4s, v14.s[0] + dup v11.4s, 
v14.s[0] + ld1 {v14.s}[0], [x13] + dup v12.4s, v14.s[0] + dup v13.4s, v14.s[0] +ApplyRound: + // Apply left shift + sqshl v16.4s, v16.4s, v8.4s + sqshl v17.4s, v17.4s, v9.4s + sqshl v18.4s, v18.4s, v8.4s + sqshl v19.4s, v19.4s, v9.4s + sqshl v20.4s, v20.4s, v8.4s + sqshl v21.4s, v21.4s, v9.4s + sqshl v22.4s, v22.4s, v8.4s + sqshl v23.4s, v23.4s, v9.4s + sqshl v24.4s, v24.4s, v8.4s + sqshl v25.4s, v25.4s, v9.4s + sqshl v26.4s, v26.4s, v8.4s + sqshl v27.4s, v27.4s, v9.4s + sqshl v28.4s, v28.4s, v8.4s + sqshl v29.4s, v29.4s, v9.4s + sqshl v30.4s, v30.4s, v8.4s + sqshl v31.4s, v31.4s, v9.4s + + // Apply the fixed-point part of the multiplier. + sqrdmulh v16.4s, v16.4s, v10.4s + sqrdmulh v17.4s, v17.4s, v11.4s + sqrdmulh v18.4s, v18.4s, v10.4s + sqrdmulh v19.4s, v19.4s, v11.4s + sqrdmulh v20.4s, v20.4s, v10.4s + sqrdmulh v21.4s, v21.4s, v11.4s + sqrdmulh v22.4s, v22.4s, v10.4s + sqrdmulh v23.4s, v23.4s, v11.4s + sqrdmulh v24.4s, v24.4s, v10.4s + sqrdmulh v25.4s, v25.4s, v11.4s + sqrdmulh v26.4s, v26.4s, v10.4s + sqrdmulh v27.4s, v27.4s, v11.4s + sqrdmulh v28.4s, v28.4s, v10.4s + sqrdmulh v29.4s, v29.4s, v11.4s + sqrdmulh v30.4s, v30.4s, v10.4s + sqrdmulh v31.4s, v31.4s, v11.4s + + // Apply right shift + and v0.16b, v12.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v12.4s + and v1.16b, v13.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v13.4s + and v2.16b, v12.16b, v18.16b + sshr v2.4s, v2.4s, #31 + sqadd v18.4s, v18.4s, v2.4s + srshl v18.4s, v18.4s, v12.4s + and v3.16b, v13.16b, v19.16b + sshr v3.4s, v3.4s, #31 + sqadd v19.4s, v19.4s, v3.4s + srshl v19.4s, v19.4s, v13.4s + and v0.16b, v12.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v12.4s + and v1.16b, v13.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v13.4s + and v2.16b, v12.16b, v22.16b + sshr v2.4s, v2.4s, #31 + sqadd v22.4s, v22.4s, v2.4s + srshl v22.4s, v22.4s, v12.4s + and v3.16b, v13.16b, v23.16b + sshr v3.4s, v3.4s, #31 + sqadd v23.4s, v23.4s, v3.4s + srshl v23.4s, v23.4s, v13.4s + and v0.16b, v12.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v12.4s + and v1.16b, v13.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v13.4s + and v2.16b, v12.16b, v26.16b + sshr v2.4s, v2.4s, #31 + sqadd v26.4s, v26.4s, v2.4s + srshl v26.4s, v26.4s, v12.4s + and v3.16b, v13.16b, v27.16b + sshr v3.4s, v3.4s, #31 + sqadd v27.4s, v27.4s, v3.4s + srshl v27.4s, v27.4s, v13.4s + and v0.16b, v12.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v12.4s + and v1.16b, v13.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v13.4s + and v2.16b, v12.16b, v30.16b + sshr v2.4s, v2.4s, #31 + sqadd v30.4s, v30.4s, v2.4s + srshl v30.4s, v30.4s, v12.4s + and v3.16b, v13.16b, v31.16b + sshr v3.4s, v3.4s, #31 + sqadd v31.4s, v31.4s, v3.4s + srshl v31.4s, v31.4s, v13.4s + + // Add the destination zero point + dup v8.4s, w10 + add v16.4s, v16.4s, v8.4s + add v17.4s, v17.4s, v8.4s + add v18.4s, v18.4s, v8.4s + add v19.4s, v19.4s, v8.4s + add v20.4s, v20.4s, v8.4s + add v21.4s, v21.4s, v8.4s + add v22.4s, v22.4s, v8.4s + add v23.4s, v23.4s, v8.4s + add v24.4s, v24.4s, v8.4s + add v25.4s, v25.4s, v8.4s + add v26.4s, v26.4s, v8.4s + add v27.4s, v27.4s, v8.4s + add v28.4s, v28.4s, v8.4s + add v29.4s, v29.4s, v8.4s + add v30.4s, 
v30.4s, v8.4s + add v31.4s, v31.4s, v8.4s + + // Apply the act_min bound + dup v7.4s, w8 + smax v16.4s, v16.4s, v7.4s + smax v17.4s, v17.4s, v7.4s + smax v18.4s, v18.4s, v7.4s + smax v19.4s, v19.4s, v7.4s + smax v20.4s, v20.4s, v7.4s + smax v21.4s, v21.4s, v7.4s + smax v22.4s, v22.4s, v7.4s + smax v23.4s, v23.4s, v7.4s + smax v24.4s, v24.4s, v7.4s + smax v25.4s, v25.4s, v7.4s + smax v26.4s, v26.4s, v7.4s + smax v27.4s, v27.4s, v7.4s + smax v28.4s, v28.4s, v7.4s + smax v29.4s, v29.4s, v7.4s + smax v30.4s, v30.4s, v7.4s + smax v31.4s, v31.4s, v7.4s + + // Apply the act_max bound + dup v6.4s, w9 + smin v16.4s, v16.4s, v6.4s + smin v17.4s, v17.4s, v6.4s + smin v18.4s, v18.4s, v6.4s + smin v19.4s, v19.4s, v6.4s + smin v20.4s, v20.4s, v6.4s + smin v21.4s, v21.4s, v6.4s + smin v22.4s, v22.4s, v6.4s + smin v23.4s, v23.4s, v6.4s + smin v24.4s, v24.4s, v6.4s + smin v25.4s, v25.4s, v6.4s + smin v26.4s, v26.4s, v6.4s + smin v27.4s, v27.4s, v6.4s + smin v28.4s, v28.4s, v6.4s + smin v29.4s, v29.4s, v6.4s + smin v30.4s, v30.4s, v6.4s + smin v31.4s, v31.4s, v6.4s + + // int32 -> int16 + sqxtn v0.4h, v16.4s + sqxtn2 v0.8h, v17.4s + sqxtn v1.4h, v18.4s + sqxtn2 v1.8h, v19.4s + sqxtn v2.4h, v20.4s + sqxtn2 v2.8h, v21.4s + sqxtn v3.4h, v22.4s + sqxtn2 v3.8h, v23.4s + sqxtn v4.4h, v24.4s + sqxtn2 v4.8h, v25.4s + sqxtn v5.4h, v26.4s + sqxtn2 v5.8h, v27.4s + sqxtn v6.4h, v28.4s + sqxtn2 v6.8h, v29.4s + sqxtn v7.4h, v30.4s + sqxtn2 v7.8h, v31.4s + + // int16 -> int8 + sqxtn v8.8b, v0.8h + sqxtn2 v8.16b, v1.8h + sqxtn v9.8b, v2.8h + sqxtn2 v9.16b, v3.8h + sqxtn v10.8b, v4.8h + sqxtn2 v10.16b, v5.8h + sqxtn v11.8b, v6.8h + sqxtn2 v11.16b, v7.8h + + cmp w23, #8 + blt Write // if rows < 8 + cmp w15, #8 + blt Write // if cols < 8 + + st1 {v8.d}[0], [x2], x24 + st1 {v8.d}[1], [x2], x24 + st1 {v9.d}[0], [x2], x24 + st1 {v9.d}[1], [x2], x24 + st1 {v10.d}[0], [x2], x24 + st1 {v10.d}[1], [x2], x24 + st1 {v11.d}[0], [x2], x24 + st1 {v11.d}[1], [x2], x24 + b Endwrite + +Write: + cmp w15, #8 + bge WriteCol8 + cmp w15, #7 + beq WriteCol7 + cmp w15, #6 + beq WriteCol6 + cmp w15, #5 + beq WriteCol5 + cmp w15, #4 + beq WriteCol4 + cmp w15, #3 + beq WriteCol3 + cmp w15, #2 + beq WriteCol2 + cmp w15, #1 + beq WriteCol1 + +WriteCol8: + st1 {v8.d}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.d}[1], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.d}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.d}[1], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.d}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.d}[1], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.d}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.d}[1], [x2], x24 + b Endwrite + +WriteCol7: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.h}[2], [x26], #2 + st1 {v8.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.h}[6], [x26], #2 + st1 {v8.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.h}[2], [x26], #2 + st1 {v9.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.h}[6], [x26], #2 + st1 {v9.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.h}[2], [x26], #2 + st1 {v10.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.h}[6], [x26], #2 + st1 {v10.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 
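+    // Each 7-byte row is emitted through the scratch pointer x26 as a 4-byte
+    // .s lane, a 2-byte .h lane and a 1-byte .b lane, so x2 itself only ever
+    // advances by the output stride in x24.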
+ st1 {v11.s}[0], [x26], #4 + st1 {v11.h}[2], [x26], #2 + st1 {v11.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.h}[6], [x26], #2 + st1 {v11.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol6: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.s}[0], [x26], #4 + st1 {v11.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.h}[6], [x26], #2 + add x2, x2, x24 + b Endwrite + +WriteCol5: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.s}[0], [x26], #4 + st1 {v11.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.b}[12], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol4: + st1 {v8.s}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.s}[2], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.s}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.s}[2], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.s}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.s}[2], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.s}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.s}[2], [x2], x24 + b Endwrite + +WriteCol3: + mov x26, x2 + st1 {v8.h}[0], [x26], #2 + st1 {v8.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.h}[4], [x26], #2 + st1 {v8.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.h}[0], [x26], #2 + st1 {v9.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.h}[4], [x26], #2 + st1 {v9.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.h}[0], [x26], #2 + st1 {v10.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.h}[4], [x26], #2 + st1 {v10.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.h}[0], [x26], #2 + st1 {v11.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.h}[4], [x26], #2 + st1 {v11.b}[10], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + st1 {v8.h}[0], [x2], x24 + cmp w23, #1 + 
beq Endwrite + st1 {v8.h}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.h}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.h}[4], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.h}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.h}[4], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.h}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.h}[4], [x2], x24 + b Endwrite + +WriteCol1: + st1 {v8.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.b}[8], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.b}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.b}[8], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.b}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.b}[8], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.b}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.b}[8], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #8 // a row8 counter - 8 + sub w23, w23, #8 // a row counter - 8 + b L2 + +End2: + sub w4, w4, #8 // b col8 counter - 8 + sub w15, w15, #8 // b col counter - 8 + add x1, x1, x21 // b ptr + stride + add x7, x7, #32 // bias ptr + stride + add x25, x25, #8 // output + stride(8 * sizeof(int8)) + mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #32 + add x11, x11, #32 + add x13, x13, #32 +PerTEnd2: + b L1 + +End1: + sub sp, sp, #208 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8Opt.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8Opt.S new file mode 100644 index 00000000..ee119b1a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulDpInt8Opt.S @@ -0,0 +1,1098 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +//void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep4, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier, +// const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, +// const int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: row +// x4: col +// x5: deep4 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// x14: stride +// x15: filter_peroc +// x28: filter_zp + +asm_function MatmulInt8DpOpt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + stp x29, x30, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr x14, [sp, #48] + ldr x15, [sp, #56] + + mov x23, #4 + mul x23, x23, x5 // lhs step + mov x24, #4 + mul x24, x24, x14 // dst step + +LoopRow: + mov x16, x1 // reload rhs ptr + mov x17, x4 // reload rhs col + mov x29, x7 // reload bias ptr + mov x25, x6 // reload input_sum ptr + mov x27, x2 // reload dst ptr + ldr x28, [sp, #64] // reload filter_zp + + LoopCol: + mov x19, x27 // reload dst ptr + mov x20, x0 // reload lhs ptr + mov x21, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp x17, #4 + ble LoopDepthQuarter + cmp x17, #8 + ble LoopDepthHalf + + LoopDepth: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x16], #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepth + + Bias: + cbz x7, NoReadBias + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x29], #64 + add v16.4s, v16.4s, v0.4s + add v17.4s, v17.4s, v1.4s + add v18.4s, v18.4s, v2.4s + add v19.4s, v19.4s, v3.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v22.4s, v22.4s, v2.4s + add v23.4s, v23.4s, v3.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v26.4s, v26.4s, v2.4s + add v27.4s, v27.4s, v3.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + add v30.4s, v30.4s, v2.4s + add v31.4s, v31.4s, v3.4s + + NoReadBias: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSum + + PerTensorSum: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v18.4s, v18.4s, v12.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v22.4s, v22.4s, v13.4s + sub v23.4s, v23.4s, v13.4s + sub v24.4s, 
v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v26.4s, v26.4s, v14.4s + sub v27.4s, v27.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + sub v30.4s, v30.4s, v15.4s + sub v31.4s, v31.4s, v15.4s + + b PerTensor + + PerChannelSum: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x28], #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v2.4s, v10.4s, v12.4s + mul v3.4s, v11.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + mul v6.4s, v10.4s, v13.4s + mul v7.4s, v11.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + mul v0.4s, v8.4s, v14.4s + mul v1.4s, v9.4s, v14.4s + mul v2.4s, v10.4s, v14.4s + mul v3.4s, v11.4s, v14.4s + mul v4.4s, v8.4s, v15.4s + mul v5.4s, v9.4s, v15.4s + mul v6.4s, v10.4s, v15.4s + mul v7.4s, v11.4s, v15.4s + sub v24.4s, v24.4s, v0.4s + sub v25.4s, v25.4s, v1.4s + sub v26.4s, v26.4s, v2.4s + sub v27.4s, v27.4s, v3.4s + sub v28.4s, v28.4s, v4.4s + sub v29.4s, v29.4s, v5.4s + sub v30.4s, v30.4s, v6.4s + sub v31.4s, v31.4s, v7.4s + + PerTensor: + cbnz x15, PerChannel + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + mov v2.16b, v0.16b + mov v3.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + mov v6.16b, v4.16b + mov v7.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + mov v10.16b, v8.16b + mov v11.16b, v8.16b + + b Quantization + + PerChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x11], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x13], #64 + + Quantization: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v18.4s, v18.4s, v2.4s + sqshl v19.4s, v19.4s, v3.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v22.4s, v22.4s, v2.4s + sqshl v23.4s, v23.4s, v3.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v26.4s, v26.4s, v2.4s + sqshl v27.4s, v27.4s, v3.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + sqshl v30.4s, v30.4s, v2.4s + sqshl v31.4s, v31.4s, v3.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v18.4s, v18.4s, v6.4s + sqrdmulh v19.4s, v19.4s, v7.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v22.4s, v22.4s, v6.4s + sqrdmulh v23.4s, v23.4s, v7.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v25.4s, v25.4s, v5.4s + sqrdmulh v26.4s, v26.4s, v6.4s + sqrdmulh v27.4s, v27.4s, v7.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + sqrdmulh v30.4s, v30.4s, v6.4s + sqrdmulh v31.4s, v31.4s, v7.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + and v2.16b, v10.16b, v18.16b + sshr v2.4s, v2.4s, #31 + sqadd v18.4s, v18.4s, v2.4s + srshl v18.4s, v18.4s, v10.4s + and v3.16b, v11.16b, v19.16b + sshr v3.4s, v3.4s, #31 + sqadd v19.4s, v19.4s, v3.4s + srshl v19.4s, v19.4s, v11.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v9.4s + and v2.16b, v10.16b, v22.16b + sshr v2.4s, v2.4s, #31 + sqadd v22.4s, v22.4s, v2.4s + srshl v22.4s, v22.4s, v10.4s + and v3.16b, v11.16b, v23.16b + sshr v3.4s, 
v3.4s, #31 + sqadd v23.4s, v23.4s, v3.4s + srshl v23.4s, v23.4s, v11.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + and v2.16b, v10.16b, v26.16b + sshr v2.4s, v2.4s, #31 + sqadd v26.4s, v26.4s, v2.4s + srshl v26.4s, v26.4s, v10.4s + and v3.16b, v11.16b, v27.16b + sshr v3.4s, v3.4s, #31 + sqadd v27.4s, v27.4s, v3.4s + srshl v27.4s, v27.4s, v11.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + and v2.16b, v10.16b, v30.16b + sshr v2.4s, v2.4s, #31 + sqadd v30.4s, v30.4s, v2.4s + srshl v30.4s, v30.4s, v10.4s + and v3.16b, v11.16b, v31.16b + sshr v3.4s, v3.4s, #31 + sqadd v31.4s, v31.4s, v3.4s + srshl v31.4s, v31.4s, v11.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v18.4s, v18.4s, v6.4s + add v19.4s, v19.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v22.4s, v22.4s, v6.4s + add v23.4s, v23.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v26.4s, v26.4s, v6.4s + add v27.4s, v27.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + add v30.4s, v30.4s, v6.4s + add v31.4s, v31.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v18.4s, v18.4s, v0.4s + smax v19.4s, v19.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v22.4s, v22.4s, v0.4s + smax v23.4s, v23.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v26.4s, v26.4s, v0.4s + smax v27.4s, v27.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + smax v30.4s, v30.4s, v0.4s + smax v31.4s, v31.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + smin v18.4s, v18.4s, v1.4s + smin v19.4s, v19.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v22.4s, v22.4s, v1.4s + smin v23.4s, v23.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v26.4s, v26.4s, v1.4s + smin v27.4s, v27.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + smin v30.4s, v30.4s, v1.4s + smin v31.4s, v31.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + sqxtn v18.4h, v18.4s + sqxtn2 v18.8h, v19.4s + sqxtn2 v0.16b, v18.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + sqxtn v22.4h, v22.4s + sqxtn2 v22.8h, v23.4s + sqxtn2 v1.16b, v22.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn2 v2.16b, v26.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + sqxtn v30.4h, v30.4s + sqxtn2 v30.8h, v31.4s + sqxtn2 v3.16b, v30.8h + + b WriteStart + + LoopDepthHalf: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthHalf + + BiasHalf: + cbz x7, NoReadBiasHalf + ld1 {v0.4s, v1.4s}, [x29] + add x29, x29, #64 + add v16.4s, v16.4s, v0.4s 
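+    // Reference for the Quantization* blocks in this kernel: each int32
+    // accumulator is requantized by a saturating left shift (sqshl), a
+    // rounding doubling multiply-high (sqrdmulh) and a rounding right shift
+    // (srshl by a negative amount, with the and/sshr/sqadd triple adjusting
+    // negative inputs before the shift). A hedged single-lane C sketch
+    // (illustrative only; the name RequantOneRef is an assumption and the
+    // sqshl saturation is omitted):
+    //
+    //   #include <stdint.h>
+    //   static int32_t RequantOneRef(int32_t acc, int left_shift, int32_t multiplier, int right_shift) {
+    //     int64_t shifted = (int64_t)acc << left_shift;                          // sqshl (saturation omitted)
+    //     int32_t high = (int32_t)((shifted * multiplier + (1LL << 30)) >> 31);  // sqrdmulh
+    //     if (right_shift == 0) return high;
+    //     return (int32_t)((high + (1LL << (right_shift - 1))) >> right_shift);  // rounding right shift
+    //   }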
+ add v17.4s, v17.4s, v1.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + + NoReadBiasHalf: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumHalf + + PerTensorSumHalf: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + + b PerTensorHalf + + PerChannelSumHalf: + ld1 {v8.4s, v9.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + mul v2.4s, v8.4s, v14.4s + mul v3.4s, v9.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + mul v7.4s, v9.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v25.4s, v25.4s, v3.4s + sub v28.4s, v28.4s, v6.4s + sub v29.4s, v29.4s, v7.4s + + PerTensorHalf: + cbnz x15, PerChannelHalf + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + + b QuantizationHalf + + PerChannelHalf: + ld1 {v0.4s, v1.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s, v5.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s, v9.4s}, [x13] + add x13, x13, #64 + + QuantizationHalf: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v25.4s, v25.4s, v5.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v9.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + 
smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + LoopDepthQuarter: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthQuarter + + BiasQuarter: + cbz x7, NoReadBiasQuarter + ld1 {v0.4s}, [x29] + add x29, x29, #64 + add v16.4s, v16.4s, v0.4s + add v20.4s, v20.4s, v0.4s + add v24.4s, v24.4s, v0.4s + add v28.4s, v28.4s, v0.4s + + NoReadBiasQuarter: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumQuarter + + PerTensorSumQuarter: + sub v16.4s, v16.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + + b PerTensorQuarter + + PerChannelSumQuarter: + ld1 {v8.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v20.4s, v20.4s, v4.4s + mul v2.4s, v8.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v28.4s, v28.4s, v6.4s + + PerTensorQuarter: + cbnz x15, PerChannelQuarter + ld1r {v0.4s}, [x12] + ld1r {v4.4s}, [x11] + ld1r {v8.4s}, [x13] + + b QuantizationHalf + + PerChannelQuarter: + ld1 {v0.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s}, [x13] + add x13, x13, #64 + + QuantizationQuarter: + sqshl v16.4s, v16.4s, v0.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v28.4s, v28.4s, v0.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v28.4s, v28.4s, v4.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v28.4s, v28.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + WriteStart: + cmp x17, #1 + beq Write1 + cmp x17, #2 + beq Write2 + cmp x17, #3 + beq Write3 + cmp x17, #4 + beq Write4 + cmp x17, #5 + beq Write5 + cmp x17, #6 + beq Write6 + cmp x17, #7 + beq Write7 + cmp x17, #8 + beq Write8 + cmp x17, #9 + beq Write9 + cmp x17, #10 + beq Write10 + cmp x17, #11 + beq Write11 + cmp x17, #12 + beq 
Write12 + cmp x17, #13 + beq Write13 + cmp x17, #14 + beq Write14 + cmp x17, #15 + beq Write15 + b Write16 + + Write1: + add x27, x27, #1 + st1 {v0.b}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.b}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.b}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.b}[0], [x19], x14 + b WriteEnd + Write2: + add x27, x27, #2 + st1 {v0.h}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + b WriteEnd + Write3: + add x27, x27, #3 + add x22, x19, #2 + st1 {v0.h}[0], [x19], x14 + st1 {v0.b}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + st1 {v1.b}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + st1 {v2.b}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + st1 {v3.b}[2], [x22], x14 + b WriteEnd + Write4: + add x27, x27, #4 + st1 {v0.s}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + b WriteEnd + Write5: + add x27, x27, #5 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.b}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.b}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.b}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.b}[4], [x22], x14 + b WriteEnd + Write6: + add x27, x27, #6 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + b WriteEnd + Write7: + add x27, x27, #7 + add x22, x19, #4 + add x26, x19, #6 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + st1 {v0.b}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + st1 {v1.b}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + st1 {v2.b}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + st1 {v3.b}[6], [x26], x14 + b WriteEnd + Write8: + add x27, x27, #8 + st1 {v0.8b}, [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + b WriteEnd + Write9: + add x27, x27, #9 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.b}[8], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.b}[8], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.b}[8], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.b}[8], [x22], x14 + b WriteEnd + Write10: + add x27, x27, #10 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + b WriteEnd + Write11: + add x27, x27, #11 + add x22, x19, #8 + add x26, x19, #10 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + st1 {v0.b}[10], [x26], x14 + cmp x3, #1 + 
beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + st1 {v1.b}[10], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + st1 {v2.b}[10], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + st1 {v3.b}[10], [x26], x14 + b WriteEnd + Write12: + add x27, x27, #12 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + b WriteEnd + Write13: + add x27, x27, #13 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.b}[12], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.b}[12], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.b}[12], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.b}[12], [x26], x14 + b WriteEnd + Write14: + add x27, x27, #14 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + b WriteEnd + Write15: + add x27, x27, #15 + add x22, x19, #8 + add x26, x19, #12 + add x21, x19, #14 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + st1 {v0.b}[14], [x21], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + st1 {v1.b}[14], [x21], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + st1 {v2.b}[14], [x21], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + st1 {v3.b}[14], [x21], x14 + b WriteEnd + Write16: + add x27, x27, #16 + st1 {v0.16b}, [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.16b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.16b}, [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.16b}, [x19], x14 + + WriteEnd: + subs x17, x17, #16 + ble LoopColEnd + mov x25, x6 + b LoopCol + +LoopColEnd: + subs x3, x3, #4 + ble LoopRowEnd + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + add x6, x6, #16 + add x0, x0, x23 + add x2, x2, x24 + b LoopRow + +LoopRowEnd: + sub sp, sp, #224 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulOptR4Int8.S b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulOptR4Int8.S new file mode 100644 index 00000000..28db29cb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly/opt/MatmulOptR4Int8.S @@ -0,0 +1,155 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + 
* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl_c/assembly_global.h" +.text +.align 5 + +//void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16, +// const int *input_sum, const int *bias) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias + +asm_function MatMulOptR4Int8Neon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + mov w15, #0 // b col index + mov w16, #0 // a row index + mov w17, #4 // sizeof(int8)*4 + mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + +L1: + cmp w15, w4 + beq End1 + + mov w16, #0 // reset a row index + mov x17, x0 // reload a ptr + mov x13, x6 // reload a_sums ptr +L2: + cmp w16, w3 + beq End2 + + mov x19, x1 // reload b ptr + mov x10, x7 // reload bias ptr + mov w11, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w11, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x19], #16 + ld1 {v5.16b}, [x19], #16 + ld1 {v6.16b}, [x19], #16 + ld1 {v7.16b}, [x19], #16 + + sdot v16.4s, v4.16b, v0.16b + sdot v17.4s, v5.16b, v0.16b + sdot v18.4s, v6.16b, v0.16b + sdot v19.4s, v7.16b, v0.16b + sdot v20.4s, v4.16b, v1.16b + sdot v21.4s, v5.16b, v1.16b + sdot v22.4s, v6.16b, v1.16b + sdot v23.4s, v7.16b, v1.16b + sdot v24.4s, v4.16b, v2.16b + sdot v25.4s, v5.16b, v2.16b + sdot v26.4s, v6.16b, v2.16b + sdot v27.4s, v7.16b, v2.16b + sdot v28.4s, v4.16b, v3.16b + sdot v29.4s, v5.16b, v3.16b + sdot v30.4s, v6.16b, v3.16b + sdot v31.4s, v7.16b, v3.16b + subs w11, w11, #16 // depth + 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x10], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + // Subtract (Asums*Zb) + ld1 {v14.4s}, [x13], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + add w16, w16, #4 // a row index + 4 + b L2 + +End2: + add w15, w15, #4 // b col index + 4 + add x1, x1, x12 // b ptr + stride + add x7, x7, #16 // bias ptr + stride 
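+ // all 4-row blocks are finished for this 4-column strip: only the B pointer (x1, stepped by the packed + // block stride in x12) and the bias pointer (x7, four int32s) advance here; the A pointer (x17) and the + // a_sums pointer (x13) are re-derived from x0 and x6 at the top of L1, so the next strip rescans the same rows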
+ b L1 + +End1: + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly_global.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly_global.h new file mode 100644 index 00000000..d1f5ca8b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/assembly_global.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ASSEMBLY_GLOBAL_H +#define NNACL_ASSEMBLY_GLOBAL_H + +// clang-format off +.macro asm_function fname +#ifdef __APPLE__ +.globl _\fname +_\fname: +#else +.global \fname +#ifdef __ELF__ +.hidden \fname +.type \fname, %function +#endif +\fname: +#endif +.endm + +// clang-format off +.macro asm_default_function fname +#ifdef __APPLE__ +.globl _\fname +_\fname: +#else +.global \fname +#ifdef __ELF__ +.type \fname, %function +#endif +\fname: +#endif +.endm + +// clang-format on + +#endif // NNACL_ASSEMBLY_GLOBAL_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/attention_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/attention_parameter.h new file mode 100644 index 00000000..f02a87fa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/attention_parameter.h @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_ATTENTION_PARAMETER_H_ +#define NNACL_ATTENTION_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct AttentionParameter { + OpParameter op_parameter_; + int head_num_; + int head_size_; + bool cross_; +} AttentionParameter; + +typedef struct RelativePositionAttentionParameter { + // Primitive parameter + OpParameter op_parameter_; + // multi-head-attention args + int num_heads_; // number of heads of multi-head-attention + int k_seq_; // length of sequence of key of attention + int v_seq_; // length of sequence of value of attention + bool use_bias_; // if matmul in attention has bias + // relative-position-attention args + int p_seq_; // length of sequence of position of attention + // args for compute + int batch_; // batch of query/key/value/position + int d_model_; // d_model of multi-head-attention + int q_seq_; // length of sequence of query of attention + int row_tile_; // row tile for matrix pack + int col_tile_; // col tile for matrix pack + int bias_tile_; // tile for bias pack +} RelativePositionAttentionParameter; + +#endif // NNACL_ATTENTION_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.c new file mode 100644 index 00000000..793aceef --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/kernel/arithmetic.h" + +void CalcMultiplesAndStrides(ArithmeticParameter *param) { + for (size_t i = 0; i < param->ndim_; i++) { + if (param->in_shape0_[i] != 0) { + param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i]; + } + if (param->in_shape1_[i] != 0) { + param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i]; + } + } + // cal strides + ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_); + ComputeStrides(param->in_shape1_, param->in_strides1_, param->ndim_); + ComputeStrides(param->out_shape_, param->out_strides_, param->ndim_); +} + +void CalcStructMultiplesAndStrides(ArithmeticStruct *arithmetic) { + for (size_t i = 0; i < arithmetic->ndim_; i++) { + if (arithmetic->in_shape0_[i] != 0) { + arithmetic->multiples0_[i] = arithmetic->out_shape_[i] / arithmetic->in_shape0_[i]; + } + if (arithmetic->in_shape1_[i] != 0) { + arithmetic->multiples1_[i] = arithmetic->out_shape_[i] / arithmetic->in_shape1_[i]; + } + } + // cal strides + ComputeStrides(arithmetic->in_shape0_, arithmetic->in_strides0_, arithmetic->ndim_); + ComputeStrides(arithmetic->in_shape1_, arithmetic->in_strides1_, arithmetic->ndim_); + ComputeStrides(arithmetic->out_shape_, arithmetic->out_strides_, arithmetic->ndim_); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.h new file mode 100644 index 00000000..af095319 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/arithmetic_base.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_ARITHMETIC_BASE_H_ +#define NNACL_BASE_ARITHMETIC_BASE_H_ + +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/nnacl_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/kernel/arithmetic.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void CalcMultiplesAndStrides(ArithmeticParameter *param); +void CalcStructMultiplesAndStrides(ArithmeticStruct *arithmetic); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_BASE_ARITHMETIC_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.c new file mode 100644 index 00000000..6e08f6f9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.c @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/batch_to_space_base.h" + +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int stride_h = block_w * out_n; + int output_offset = 0; + int copy_size = in_c * data_size; + int in_stride_h = in_w * in_c; + int in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + int h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + int w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy((int8_t *)output + output_offset, (int8_t *)input + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} + +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size) { + int block_h = block[0]; + int block_w = block[1]; + if (block_h == 0 || block_w == 0) { + return; + } + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + int stride_h = block_w * out_n; + int output_offset = 0; + int copy_size = in_c * data_size; + int in_stride_h = in_w * in_c; + int in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + int h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + int h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + int w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + int in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy((int8_t *)output + output_offset, (int8_t *)input + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.h new file mode 100644 index 00000000..c85dd380 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/batch_to_space_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BATCH_TO_SPACE_BASE_H_ +#define NNACL_BATCH_TO_SPACE_BASE_H_ + +#include +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size); +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BATCH_TO_SPACE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.c new file mode 100644 index 00000000..8853abf0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.c @@ -0,0 +1,106 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/base/broadcast_to.h" +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +size_t accumulate(const int *shape, int start, int end) { + size_t product = 1; + for (int i = start; i <= end; ++i) { + product *= (size_t)shape[i]; + } + return product; +} + +void pad_input_shape(int *input_shape, int input_shape_len, int output_shape_len) { + if (input_shape_len < output_shape_len) { + const int shape_gap = output_shape_len - input_shape_len; + for (int i = input_shape_len - 1; i >= 0; --i) { + input_shape[i + shape_gap] = input_shape[i]; + } + for (int i = 0; i < shape_gap; ++i) { + input_shape[i] = 1; + } + } +} + +#define BROADCAST_TO_SIZE_IMPL(data_size) \ + int BroadcastToSize##data_size(const void *input, BroadcastShapeInfo *shape_info, void *output) { \ + if (input == NULL || output == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + if (shape_info->output_shape_size_ > MAX_SHAPE_SIZE) { \ + return NNACL_ERR; \ + } \ + int *input_shape = shape_info->input_shape_; \ + const int *output_shape = shape_info->output_shape_; \ + const int dim_max = shape_info->output_shape_size_ - 1; \ + const size_t temp_length = accumulate(output_shape, 0, dim_max); \ + const size_t data_len = data_size / BYTE_SIZE; \ + if (temp_length * data_len == 0) { \ + return NNACL_ERR; \ + } \ + int8_t *data_temp = (int8_t *)malloc(temp_length * data_len); \ + if (data_temp == NULL) { \ + return NNACL_ERR; \ + } \ + pad_input_shape(input_shape, shape_info->input_shape_size_, dim_max + 1); \ + shape_info->input_shape_size_ = dim_max + 1; \ + \ + size_t before_dim_elements_num = accumulate(input_shape, 0, dim_max - 1); \ + size_t after_dim_elements_num = (size_t)(input_shape[dim_max]); \ + size_t dim_broadcast_rate = (size_t)(output_shape[dim_max] / input_shape[dim_max]); \ + for (size_t i = 0; i < before_dim_elements_num; ++i) { \ + const int8_t *in_ptr = (const int8_t *)input + i * after_dim_elements_num * data_len; \ + for (size_t j = 0; j < dim_broadcast_rate; ++j) { \ + int8_t *out_ptr = (int8_t *)output + (i * dim_broadcast_rate + j) * after_dim_elements_num * data_len; \ + memcpy(out_ptr, in_ptr, after_dim_elements_num *data_len); \ + } \ + } \ + \ + int dim_index = dim_max - 1; \ + while (dim_index >= 0) { \ + if (input_shape[dim_index] == 0) { \ + free(data_temp); \ + return NNACL_ERR; \ + } \ + dim_broadcast_rate = (size_t)(output_shape[dim_index] / input_shape[dim_index]); \ + if (dim_broadcast_rate > 1) { \ + before_dim_elements_num = accumulate(input_shape, 0, dim_index - 1); \ + after_dim_elements_num = accumulate(output_shape, dim_index + 1, dim_max); \ + for (size_t i = 0; i < before_dim_elements_num; ++i) { \ + int8_t *in_ptr = (int8_t *)output + i * after_dim_elements_num * data_len; \ + for (size_t j = 0; j < dim_broadcast_rate; ++j) { \ + int8_t *out_ptr = data_temp + (i * dim_broadcast_rate + j) * after_dim_elements_num * data_len; \ + memcpy(out_ptr, in_ptr, after_dim_elements_num *data_len); \ + } \ + } \ + size_t elements_total = before_dim_elements_num * dim_broadcast_rate * after_dim_elements_num; \ + memcpy(output, data_temp, elements_total *data_len); \ + } \ + --dim_index; \ + } \ + free(data_temp); \ + return NNACL_OK; \ + } + +BROADCAST_TO_SIZE_IMPL(8) +BROADCAST_TO_SIZE_IMPL(16) +BROADCAST_TO_SIZE_IMPL(32) +BROADCAST_TO_SIZE_IMPL(64) +BROADCAST_TO_SIZE_IMPL(128) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.h new file mode 100644 index 
00000000..d13114b0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/broadcast_to.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_BROADCAST_TO_H_ +#define NNACL_FP32_BROADCAST_TO_H_ + +#include "nnacl_c/broadcast_to_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define BYTE_SIZE 8 +int BroadcastToSize8(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize16(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize32(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize64(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize128(const void *input, BroadcastShapeInfo *shape_info, void *output); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_BROADCAST_TO_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.c new file mode 100644 index 00000000..6d0abc46 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.c @@ -0,0 +1,199 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/base/cast_base.h" +#include "nnacl_c/cast_base_simd.h" + +typedef union float32_bits { + unsigned int u; + float f; +} float32_bits; + +uint16_t Float32ToFloat16_(float f) { + float32_bits hbit; + hbit.f = f; + uint16_t hbits = 0; + // Extract the sign bit + uint16_t sign = (hbit.u >> FP16_BIT_SIZE) & 0x8000; // Get the sign (1 bit) 0x8000 + // Extract the exponent + uint32_t exponent = (hbit.u >> FP32_SIGNIFICAND) & 0xFF; // Extract the exponent (8 bits) 0xFF + // Handle special cases (NaN, Inf, 0) + if (exponent == 0xFF) { // NaN or Infinity 0xFF + hbits |= sign | 0x7FFF; // Saturate into the float16 Inf/NaN range + return hbits; + } else if (exponent == 0) { // Zero or denormalized number + // In float16, we treat zero the same way + hbits |= sign; // Preserve sign for zero + return hbits; + } + // Rebias the exponent for float16; the result can go negative for small inputs, so use a signed type + // (an unsigned subtraction here would wrap and send small magnitudes through the overflow branch) + int32_t new_exp = (int32_t)exponent - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS; + // Check for overflow + if (new_exp >= 0x1F) { // 0x1F + hbits |= sign | 0x7FFF; // Saturate into the float16 Inf/NaN range 0x7FFF + return hbits; + } + if (new_exp <= 0) { + // Handle underflow (too small to represent, including float16 denormals) + return sign; // Return zero with the correct sign + } + // Shift the mantissa: + // Extract the mantissa (23 bits), shift right by 13 (23 - 10) + uint32_t mantissa = (hbit.u & 0x7FFFFF) >> FP16_SHIFT; // 0x7FFFFF + // Combine sign, exponent, and mantissa into hbits + hbits |= + sign | ((uint16_t)new_exp << FP16_SIGNIFICAND) | (mantissa & 0x3FF); // combine sign exponent and mantissa 0x3FF + return hbits; +} + +void Int32ToFloat32(const int32_t *input, float *output, int number) { + int index = 0; + + SIMD_RUN_NO_SCALAR(Int32ToFloat32, index, input, output, number); + + for (; index < number; ++index) { + output[index] = (float)input[index]; + } +} + +void Float32ToInt32(const float *input, int32_t *output, int number) { + int index = 0; + + SIMD_RUN_X86_NO_SCALAR(Float32ToInt32, index, input, output, number); + + for (; index < number; ++index) { + output[index] = (int32_t)input[index]; + } +} + +void BoolToFloat32(const bool *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Uint8ToFloat32(const uint8_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Int32ToFloat32(const int32_t *input, float *output, int number); + +void Int64ToFloat32(const int64_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +#ifdef ENABLE_FP16 +void Int64ToFp16(const int64_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Int32ToFp16(const int32_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void BoolToFp16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Uint8ToFp16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Float32ToFp16(const float *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)(input[i]); + } +} + +void Fp16ToFloat32(const float16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] =
(float)(input[i]); + } +} +#else +void Fp16ToFloat32(const uint16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ShortToFloat32(input[i]); + } +} + +void Float32ToFp16(const float *input, uint16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = Float32ToFloat16_(input[i]); + } +} +#endif + +void Float32ToInt32(const float *input, int32_t *output, int number); + +void Float32ToInt64(const float *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +void Int32ToInt64(const int32_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +void Int64ToInt32(const int64_t *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + +void Float32ToInt16(const float *input, int16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int16_t)input[i]; + } +} + +void BoolToInt32(const bool *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + if (input[i]) { + output[i] = 1; + } else { + output[i] = 0; + } + } +} + +void Float32ToBool(const float *input, bool *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (bool)input[i]; + } +} + +void Float32ToUint8(const float *input, uint8_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (uint8_t)input[i]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.h new file mode 100644 index 00000000..52db13c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_BASE_CAST_BASE_H_ +#define NNACL_BASE_CAST_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BoolToFloat32(const bool *input, float *output, int number); + +void Uint8ToFloat32(const uint8_t *input, float *output, int number); + +void Int32ToFloat32(const int32_t *input, float *output, int number); + +void Int64ToFloat32(const int64_t *input, float *output, int number); + +#ifdef ENABLE_FP16 +void Int64ToFp16(const int64_t *input, float16_t *output, int number); + +void Int32ToFp16(const int32_t *input, float16_t *output, int number); + +void BoolToFp16(const bool *input, float16_t *output, int number); + +void Uint8ToFp16(const uint8_t *input, float16_t *output, int number); + +void Float32ToFp16(const float *input, float16_t *output, int number); + +void Fp16ToFloat32(const float16_t *input, float *output, int number); +#else +void Fp16ToFloat32(const uint16_t *input, float *output, int number); + +void Float32ToFp16(const float *input, uint16_t *output, int number); +#endif + +uint16_t Float32ToFloat16_(float f); + +void Float32ToInt32(const float *input, int32_t *output, int number); + +void Float32ToInt64(const float *input, int64_t *output, int number); + +void Int32ToInt64(const int32_t *input, int64_t *output, int number); + +void Int64ToInt32(const int64_t *input, int32_t *output, int number); + +void Float32ToInt16(const float *input, int16_t *output, int number); + +void BoolToInt32(const bool *input, int32_t *output, int number); + +void Float32ToBool(const float *input, bool *output, int number); + +void Float32ToUint8(const float *input, uint8_t *output, int number); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CAST_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base_simd.h.in new file mode 100644 index 00000000..b7ad2f32 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/cast_base_simd.h.in @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_BASE_CAST_BASE_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_CAST_BASE_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int Int32ToFloat32@SIMD_INSTRUCTION@(int index, const int32_t *input, float *output, int number) { + for (int block_max_size = number - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 value = SIMD_LD_EPI32(input + index); + SIMD_ST_F32(output + index, SIMD_EPI32_TO_F32(value)); + } + return index; +} + +#ifndef MS_SIMD_NEON +static inline int Float32ToInt32@SIMD_INSTRUCTION@(int index, const float *input, int32_t *output, int number) { + for (int block_max_size = number - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(input + index); + SIMD_ST_EPI32(output + index, SIMD_F32_TO_EPI32(value)); + } + return index; +} +#endif + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.c new file mode 100644 index 00000000..b40f4473 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.c @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/base/concat_base.h" + +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, + int task_id, int thread_num, int data_size) { + int before_axis_size = 1; + for (int i = 0; i < axis; ++i) { + before_axis_size *= inputs_output_shape[0][i]; + } + + int after_axis_size = data_size; + for (size_t i = (size_t)(axis) + 1; i < shape_size; ++i) { + after_axis_size *= inputs_output_shape[0][i]; + } + int axis_offset = 0; + uint8_t *dst_base = (output); + int output_stride = after_axis_size * inputs_output_shape[input_num][axis]; + for (int i = 0; i < input_num; ++i) { + const uint8_t *src_base = (input[i]); + if (inputs_output_shape[i] == NULL) { + continue; + } + int input_stride = after_axis_size * inputs_output_shape[i][axis]; + NNACL_CHECK_ZERO_RETURN(thread_num); + int offset = UP_DIV(input_stride, thread_num); + int count = input_stride - offset * task_id; + if (count <= 0) { + axis_offset += inputs_output_shape[i][axis]; + continue; + } + count = MSMIN(offset, count); + for (int j = 0; j < before_axis_size; j++) { + const uint8_t *src = src_base + j * input_stride + task_id * offset; + uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size + task_id * offset; + memcpy(dst, src, count); + } + axis_offset += inputs_output_shape[i][axis]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.h new file mode 100644 index 00000000..ea85f6e6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/concat_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_CONCAT_BASE_H_ +#define NNACL_FP32_CONCAT_BASE_H_ + +#include +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, + int task_id, int thread_num, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONCAT_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.c new file mode 100644 index 00000000..240b8516 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/conv1x1_base.h" + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) { + /* support nhwc */ + char *src = (char *)src_ptr; + char *dst = (char *)dst_ptr; + for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { + int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_; + if (src_h < 0 || src_h >= conv_param->input_h_) { + continue; + } + const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size; + char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size; + for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { + int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_; + if (src_w < 0 || src_w >= conv_param->input_w_) { + continue; + } + memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size, + src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size); + } + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.h new file mode 100644 index 00000000..6ab0322b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv1x1_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_CONV1X1_BASE_H_ +#define NNACL_BASE_CONV1X1_BASE_H_ + +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CONV1X1_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.c new file mode 100644 index 00000000..59aac49c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.c @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/errorcode.h" + +#define MIN_UNIT 2 +#define MAX_UNIT 8 + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) { + return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 && + conv_param->input_channel_ == conv_param->output_channel_ && conv_param->output_w_ >= 4 && + conv_param->output_h_ >= thread_num * 4; // better had more than 4 rows for each thread +} +#endif + +bool CheckWinogradInputOutputUnit(int input_unit, int output_unit) { + if (input_unit != 4 && input_unit != 6 && input_unit != 8) { + return false; + } + if ((output_unit >= input_unit) || (output_unit < 2)) { + return false; + } + return true; +} + +// Reference to the paper "Fast Algorithms for Convolutional Neural Networks" +// Utilize cost model to compute performance gain. +// If the gain is greater than got from Im2col, winograd algorithm will be chosen. +int SelectOutputUnit(const ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_c = conv_param->input_channel_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_c = conv_param->output_channel_; + if (conv_param->op_parameter_.thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + int unit2 = UP_DIV(out_w * out_h, C12NUM * conv_param->op_parameter_.thread_num_); + int max_out_unit = (int)(sqrtf((float)unit2)); + max_out_unit = max_out_unit < MAX_UNIT ? max_out_unit : MAX_UNIT; + max_out_unit = max_out_unit > MIN_UNIT ? 
max_out_unit : MIN_UNIT; + + int unit = 0; + float max_rate = 0.0f; + float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w; + + for (int i = MIN_UNIT; i <= max_out_unit; ++i) { + int input_unit = i + kernel_w - 1; + if (!CheckWinogradInputOutputUnit(input_unit, i)) { + continue; + } + float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f; + float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) * + UP_DIV(out_w, i) * UP_DIV(out_h, i); + float reduce_rate = common_cost / wino_cost - penalty; + if (reduce_rate > max_rate) { + max_rate = reduce_rate; + unit = i; + } + } + if (max_rate < 1.0f) { + return 1; + } + // If output_unit is 1, then it is conventional convolution + return unit; +} + +bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + return false; + } + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ != 1) { + *output_unit = SelectOutputUnit(conv_param); + if (*output_unit > 1) { + return true; + } + } + return false; +} + +bool CheckAvxUseSW1x1Conv(const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + return true; + } + } + return false; +} + +bool CheckAvxUseSWConv(const ConvParameter *conv_param, int thread_nr_) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_w_, conv_param->input_h_, false); + if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ % C8NUM == 0 && + (conv_param->input_w_ * conv_param->input_h_ >= thread_nr_)) { + return true; + } + } else { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ > C128NUM) { // conv1d kernel + return false; + } else if (conv_param->input_channel_ / conv_param->op_parameter_.thread_num_ <= C16NUM && + conv_param->input_h_ >= thread_nr_ && + (conv_param->kernel_h_ < C7NUM || conv_param->input_h_ / conv_param->kernel_h_ >= C4NUM) && + (conv_param->kernel_w_ < C7NUM || conv_param->input_w_ / conv_param->kernel_w_ >= C4NUM)) { + return true; + } + } + return false; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.h new file mode 100644 index 00000000..29dfa569 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/conv_common_base.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_CONV_DEPTHWISE_BASE_H_ +#define NNACL_BASE_CONV_DEPTHWISE_BASE_H_ + +#include "nnacl_c/conv_parameter.h" + +bool CheckAvxUseSW1x1Conv(const ConvParameter *conv_param); +bool CheckAvxUseSWConv(const ConvParameter *conv_param, int thread_nr_); + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num); +#endif + +bool CheckWinogradInputOutputUnit(int input_unit, int output_unit); + +bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CONV_DEPTHWISE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.c new file mode 100644 index 00000000..26b58e0f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.c @@ -0,0 +1,40 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/crop_base.h" +#include "nnacl_c/errorcode.h" + +int CropPadOffset(int input_dim, CropParameter *crop_para, int64_t *in_offset) { + int64_t axis = crop_para->axis_; + int offsets_size = crop_para->offset_size_; + if (offsets_size > 1) { + NNACL_CHECK_TRUE_RET(axis + offsets_size == input_dim, NNACL_ERR); + } + for (int i = 0; i < input_dim; i++) { + int crop_offset = 0; + if (i >= axis) { + if (offsets_size == 1) { + crop_offset = crop_para->offset_[0]; + } else if (offsets_size > 1) { + if (i - axis < CROP_OFFSET_MAX_SIZE) { + crop_offset = crop_para->offset_[i - axis]; + } + } + } + in_offset[i] = crop_offset; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.h new file mode 100644 index 00000000..4b036d88 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/crop_base.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_BASE_CROP_BASE_H_ +#define NNACL_BASE_CROP_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#define CROP_OFFSET_MAX_SIZE 4 + +#ifdef __cplusplus +extern "C" { +#endif + +int CropPadOffset(int input_dim, CropParameter *crop_para, int64_t *in_offset); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CROP_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.c new file mode 100644 index 00000000..11bd8478 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.c @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/depth_to_space_base.h" +#include "nnacl_c/errorcode.h" + +void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = (size_t)block_size * param->out_stride_dim2_ * param->data_type_size_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + int64_t out_offset = (out_offset_w + l * param->out_stride_dim1_) * param->data_type_size_; + int64_t in_offset = (in_offset_w + l * block_size * param->out_stride_dim2_) * param->data_type_size_; + memcpy((int8_t *)output + out_offset, (int8_t *)input + in_offset, copy_size); + } + } + } + } +} + +void DepthToSpaceCRDForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim3 = in_shape[3]; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = param->data_type_size_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < in_shape_dim3; ++l) { + int64_t offset = l % (block_size * block_size); + int64_t 
out_offset_c =
+            out_offset_w +
+            offset / block_size * block_size * in_shape_dim2 * in_shape_dim3 / (block_size * block_size) +
+            offset % block_size * in_shape_dim3 / (block_size * block_size);
+          int64_t out_offset = (out_offset_c + l / (block_size * block_size)) * param->data_type_size_;
+          int64_t in_offset = (in_offset_w + l) * param->data_type_size_;
+          memcpy((int8_t *)output + out_offset, (int8_t *)input + in_offset, copy_size);
+        }
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.h
new file mode 100644
index 00000000..57c0e38a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/depth_to_space_base.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_DEPTH_TO_SPACE_H_
+#define NNACL_DEPTH_TO_SPACE_H_
+
+#include <string.h>
+#include "nnacl_c/kernel/depth_to_space.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param);
+void DepthToSpaceCRDForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_DEPTH_TO_SPACE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.c
new file mode 100644
index 00000000..9e6a05b1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.c
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/base/fill_base.h" +#include "nnacl_c/fill_base_simd.h" + +int FillFp32(float *output, int size, float data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + int index = 0; + + SIMD_RUN_NO_SCALAR(FillFp32, index, output, size, data); + + for (; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; +} + +int FillInt32(int *output, int size, int data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + int index = 0; + + SIMD_RUN_NO_SCALAR(FillInt32, index, output, size, data); + + for (; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; +} + +int FillBool(bool *output, int size, bool data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + for (int index = 0; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.h new file mode 100644 index 00000000..8da977c7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FILL_BASE_H_ +#define NNACL_FILL_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fill_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int FillFp32(float *output, int size, float data); +int FillInt32(int *output, int size, int data); +int FillBool(bool *output, int size, bool data); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FILL_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base_simd.h.in new file mode 100644 index 00000000..08bfb4c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/fill_base_simd.h.in @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_BASE_FILL_BASE_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_FILL_BASE_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int FillFp32@SIMD_INSTRUCTION@(int index, float *output, int size, float data) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MOV_F32(data)); + } + return index; +} + +static inline int FillInt32@SIMD_INSTRUCTION@(int index, int *output, int size, int data) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_MOV_EPI32(data)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c new file mode 100644 index 00000000..e062b901 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c @@ -0,0 +1,81 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/base/format_transpose.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/pack_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/pack_fp16.h" +#endif + +int TransposeFp32Data(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, + const int batch, const int channel, const int plane) { + if (src_format == Format_NHWC && dst_format == Format_NCHW) { + PackNHWCToNCHWFp32(src_data, dst_data, batch, plane, channel, 0, 1); + } else if (src_format == Format_NCHW && dst_format == Format_NHWC) { + PackNCHWToNHWCFp32(src_data, dst_data, batch, plane, channel, 0, 1); + } else if (src_format == Format_NCHW && dst_format == Format_NC4HW4) { + PackNCHWToNC4HW4Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC4HW4 && dst_format == Format_NCHW) { + PackNC4HW4ToNCHWFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC4HW4) { + PackNHWCToNC4HW4Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC4HW4 && dst_format == Format_NHWC) { + PackNC4HW4ToNHWCFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC8HW8) { + PackNHWCToNC8HW8Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NHWC) { + PackNC8HW8ToNHWCFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NCHW) { + PackNC8HW8ToNCHWFp32(src_data, dst_data, batch, plane, channel); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +#ifdef ENABLE_FP16 +int TransposeFp16Data(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, int batch, + int channel, int plane) { + if (src_format == Format_NCHW && dst_format == Format_NC8HW8) { + PackNCHWFp16ToNC8HW8Fp16(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC8HW8) { + return NNACL_ERR; + } else if (src_format == Format_NC8HW8 && dst_format == Format_NCHW) { + PackNC8HW8ToNCHWFp16(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NHWC) { + PackNC8HW8ToNHWCFp16((float16_t *)src_data, (float16_t *)dst_data, batch, plane, channel); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} +#endif + +int TransData(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, TypeIdC data_type, + const int batch, const int channel, const int plane) { + switch (data_type) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + return TransposeFp32Data(src_data, dst_data, src_format, dst_format, batch, channel, plane); +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + return TransposeFp16Data(src_data, dst_data, src_format, dst_format, batch, channel, plane); +#endif + default: + return NNACL_ERR; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.h new file mode 100644 index 00000000..638e2f0f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/format_transpose.h @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FORMAT_TRANSPOSE_H_
+#define NNACL_FORMAT_TRANSPOSE_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int TransData(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, TypeIdC data_type,
+              const int batch, const int channel, const int plane);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FORMAT_TRANSPOSE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.c
new file mode 100644
index 00000000..8721b568
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.c
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string.h>
+#include "nnacl_c/base/gather_base.h"
+
+int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices,
+           int64_t index_num, void *output, int64_t byte_out_stride, int *error_index) {
+  if (input == NULL || output == NULL || indices == NULL || error_index == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  const int8_t *int8_in = (int8_t *)input;
+  int8_t *int8_out = (int8_t *)output;
+  int64_t in_stride = byte_inner_size * limit;
+  for (int64_t m = 0; m < outer_size; ++m) {
+    int8_t *int8_out_m = int8_out;
+    for (int64_t i = 0; i < index_num; ++i) {
+      int index = indices[i];
+      index = index < 0 ? index + limit : index;
+      if (index < 0 || index >= limit) {
+        *error_index = index;
+        return NNACL_GATHER_INDICES_VALUE_INVALID;
+      } else {
+        memcpy(int8_out_m, int8_in + index * byte_inner_size, byte_inner_size);
+      }
+      int8_out_m += byte_inner_size;
+    }
+    int8_in += in_stride;
+    int8_out += byte_out_stride;
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.h
new file mode 100644
index 00000000..f47a19b6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_base.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_GATHER_BASE_H_
+#define NNACL_GATHER_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices,
+           int64_t index_num, void *output, int64_t byte_out_stride, int *error_index);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_GATHER_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.c
new file mode 100644
index 00000000..93d460b2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.c
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string.h>
+#include "nnacl_c/base/gather_d_base.h"
+
+int CheckIndexValue_int32_t(int32_t *index, const int max_index, const size_t *index_shape,
+                            const size_t index_shape_size) {
+  // check index
+  size_t index_size = 1;
+  for (size_t i = 0; i < index_shape_size; ++i) {
+    index_size *= index_shape[i];
+  }
+
+  for (size_t i = 0; i < index_size; ++i) {
+    if (index[i] >= max_index || index[i] < -max_index) {
+      return NNACL_ERR;
+    }
+    if (index[i] < 0) {
+      index[i] = max_index + index[i];
+    }
+  }
+  return NNACL_OK;
+}
+
+int CheckIndexValue_int64_t(int64_t *index, const int max_index, const size_t *index_shape,
+                            const size_t index_shape_size) {
+  // check index
+  size_t index_size = 1;
+  for (size_t i = 0; i < index_shape_size; ++i) {
+    index_size *= index_shape[i];
+  }
+  for (size_t i = 0; i < index_size; ++i) {
+    if (index[i] >= max_index || index[i] < -max_index) {
+      return NNACL_ERR;
+    }
+    if (index[i] < 0) {
+      index[i] = max_index + index[i];
+    }
+  }
+  return NNACL_OK;
+}
+
+int InitCalVec(size_t *in_strides, size_t *out_strides, size_t *pos, const size_t *input_shape,
+               const size_t input_shape_size, const size_t *output_shape, const size_t output_shape_size) {
+  // in_strides
+  NNACL_CHECK_NULL_RETURN_ERR(in_strides);
+  for (size_t i = 0; i < input_shape_size; ++i) {
+    in_strides[i] = 1;
+  }
+  for (int i = (int)input_shape_size - 2; i >= 0; --i) {
+    in_strides[i] = input_shape[i + 1] * in_strides[i + 1];
+  }
+
+  // out_strides
+  NNACL_CHECK_NULL_RETURN_ERR(out_strides);
+  for (size_t i = 0; i < output_shape_size; ++i) {
+    out_strides[i] = 1;
+  }
+  for (int i = (int)output_shape_size - 2; i >= 0; --i) {
+    out_strides[i] = output_shape[i + 1] * out_strides[i + 1];
+  }
+
+  NNACL_CHECK_NULL_RETURN_ERR(pos);
+  for (size_t i = 0; i < output_shape_size; ++i) {
+    pos[i] = 0;
+  }
+  return NNACL_OK;
+}
+
+#define COPY_TASK_IMPL(type0, type1)                                                                        \
+  int CopyTask_Input_##type0##_Index_##type1(                                                               \
+    type0 *output, const type0 *input, const type1 *index, size_t cur_dim, size_t *pos, const size_t dim,  \
+    const size_t *output_shape, const size_t output_shape_size, const
size_t *in_strides, const size_t *out_strides) { \ + if (pos == NULL || out_strides == NULL || in_strides == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + for (size_t i = 0; i < output_shape[cur_dim]; ++i) { \ + pos[cur_dim] = i; \ + if (cur_dim == output_shape_size - 1) { \ + size_t input_offset = 0; \ + size_t out_offset = 0; \ + for (size_t j = 0; j < output_shape_size; ++j) { \ + out_offset += pos[j] * out_strides[j]; \ + } \ + size_t cur_index = pos[dim]; \ + pos[dim] = index[out_offset]; \ + for (size_t j = 0; j < output_shape_size; ++j) { \ + input_offset += pos[j] * in_strides[j]; \ + } \ + ((type0 *)output)[out_offset] = ((const type0 *)input)[input_offset]; \ + pos[dim] = cur_index; \ + } else { \ + CopyTask_Input_##type0##_Index_##type1(output, input, index, cur_dim + 1, pos, dim, output_shape, \ + output_shape_size, in_strides, out_strides); \ + } \ + } \ + return NNACL_OK; \ + } + +COPY_TASK_IMPL(bool, int32_t) +COPY_TASK_IMPL(bool, int64_t) +COPY_TASK_IMPL(int16_t, int32_t) +COPY_TASK_IMPL(int16_t, int64_t) +COPY_TASK_IMPL(int32_t, int32_t) +COPY_TASK_IMPL(int32_t, int64_t) +COPY_TASK_IMPL(int64_t, int32_t) +COPY_TASK_IMPL(int64_t, int64_t) +COPY_TASK_IMPL(float, int32_t) +COPY_TASK_IMPL(float, int64_t) +#ifdef ENABLE_FP16 +COPY_TASK_IMPL(float16_t, int32_t) +COPY_TASK_IMPL(float16_t, int64_t) +#endif + +#define GATHER_D_IMPL(type0, type1) \ + GATHER_D_IMPL_DECLARATION(type0, type1) { \ + if (output == NULL || input == NULL || index == NULL || input_shape == NULL || output_shape == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + int max_index = input_shape[dim]; \ + int ret = CheckIndexValue_##type1(index, max_index, output_shape, output_shape_size); \ + if (ret != NNACL_OK) { \ + return ret; \ + } \ + size_t in_strides[MAX_SHAPE_SIZE]; \ + size_t out_strides[MAX_SHAPE_SIZE]; \ + size_t pos[MAX_SHAPE_SIZE]; \ + ret = InitCalVec(in_strides, out_strides, pos, input_shape, input_shape_size, output_shape, output_shape_size); \ + if (ret != NNACL_OK) { \ + return ret; \ + } \ + ret = CopyTask_Input_##type0##_Index_##type1(output, input, index, 0, pos, dim, output_shape, output_shape_size, \ + in_strides, out_strides); \ + return ret; \ + } + +GATHER_D_IMPL(bool, int32_t) +GATHER_D_IMPL(bool, int64_t) +GATHER_D_IMPL(int16_t, int32_t) +GATHER_D_IMPL(int16_t, int64_t) +GATHER_D_IMPL(int32_t, int32_t) +GATHER_D_IMPL(int32_t, int64_t) +GATHER_D_IMPL(int64_t, int32_t) +GATHER_D_IMPL(int64_t, int64_t) +GATHER_D_IMPL(float, int32_t) +GATHER_D_IMPL(float, int64_t) +#ifdef ENABLE_FP16 +GATHER_D_IMPL(float16_t, int32_t) +GATHER_D_IMPL(float16_t, int64_t) +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.h new file mode 100644 index 00000000..a8270b01 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/gather_d_base.h @@ -0,0 +1,55 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_GATHER_D_BASE_H_ +#define NNACL_GATHER_D_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define GATHER_D(type0, type1, output, input, index, input_shape, input_shape_size, output_shape, output_shape_size, \ + dim) \ + GatherD_Input_##type0##_Index_##type1(output, input, index, input_shape, input_shape_size, output_shape, \ + output_shape_size, dim) + +#define GATHER_D_IMPL_DECLARATION(type0, type1) \ + int GatherD_Input_##type0##_Index_##type1( \ + type0 *output, const type0 *input, type1 *index, const size_t *input_shape, const size_t input_shape_size, \ + const size_t *output_shape, const size_t output_shape_size, const size_t dim) + +GATHER_D_IMPL_DECLARATION(bool, int32_t); +GATHER_D_IMPL_DECLARATION(bool, int64_t); +GATHER_D_IMPL_DECLARATION(int16_t, int32_t); +GATHER_D_IMPL_DECLARATION(int16_t, int64_t); +GATHER_D_IMPL_DECLARATION(int32_t, int32_t); +GATHER_D_IMPL_DECLARATION(int32_t, int64_t); +GATHER_D_IMPL_DECLARATION(int64_t, int32_t); +GATHER_D_IMPL_DECLARATION(int64_t, int64_t); +GATHER_D_IMPL_DECLARATION(float, int32_t); +GATHER_D_IMPL_DECLARATION(float, int64_t); +#ifdef ENABLE_FP16 +GATHER_D_IMPL_DECLARATION(float16_t, int32_t); +GATHER_D_IMPL_DECLARATION(float16_t, int64_t); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_GATHER_D_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.c new file mode 100644 index 00000000..41226beb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.c @@ -0,0 +1,342 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "nnacl_c/base/minimal_filtering_generator.h"
+#include <string.h>
+#include <math.h>
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+void Polynomial(const float *interval, float *m, int degree) {
+  for (int i = 0; i < degree; ++i) {
+    float mul = 1;
+    for (int j = 0; j < degree; ++j) {
+      if (i == j) {
+        continue;
+      }
+      mul *= (interval[i] - interval[j]);
+    }
+    m[i] = mul;
+  }
+}
+
+void DiagonalPlusMatrix(const float *matrix, float *diagonal_matrix, int degree) {
+  int data_num = (degree + 1) * (degree + 1);
+  memset(diagonal_matrix, 0, data_num * sizeof(float));
+  for (int i = 0; i < degree; ++i) {
+    for (int j = 0; j < degree; ++j) {
+      if (j == i) {
+        diagonal_matrix[i * (degree + 1) + j] = matrix[i];
+      }
+    }
+  }
+  diagonal_matrix[data_num - 1] = 1;
+}
+
+void ResidueMatrix(const float *interval, float *b, int row, int col) {
+  // row : input unit, col : output_unit
+  // result : matrix b
+  int len = row * col;
+  memset(b, 0, len * sizeof(float));
+  for (int i = 0; i < row - 1; ++i) {
+    for (int j = 0; j < col; ++j) {
+      b[i * col + j] = pow(interval[i], j);
+    }
+  }
+  b[len - 1] = 1;
+}
+
+int LT(const float *poly_array, float *matrix_lt, int n) {
+  if (n > MAX_LEN) {
+    return NNACL_ERR;
+  }
+  float coefficient_array[MAX_LEN];  // n
+  float poly[MAX_LEN];               // n
+
+  Polynomial(poly_array, poly, n);
+  for (int i = 0; i < n; ++i) {
+    // get coefficient
+    int index = 1;
+    memset(coefficient_array, 0, n * sizeof(float));
+    coefficient_array[0] = 1;
+    for (int j = 0; j < n; ++j) {
+      if (j == i) continue;
+      float poly_coe = poly_array[j] == 0 ? 0 : -poly_array[j];
+      coefficient_array[index] = 1;
+      for (int k = index - 1; k > 0; --k) {
+        coefficient_array[k] = coefficient_array[k] * poly_coe + coefficient_array[k - 1];
+      }
+      coefficient_array[0] *= poly_coe;
+      index++;
+    }
+
+    // lx[i, 0].nth(j) / f[i]
+    int setp = i * n;
+    for (int l = 0; l < n; ++l) {
+      matrix_lt[setp + l] = coefficient_array[l] / poly[i];
+    }
+  }  // matrix L row loop
+  return NNACL_OK;
+}
+
+void T(const float *poly_array, float *matrix_t, int n) {
+  memset(matrix_t, 0, n * (n + 1) * sizeof(float));
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < n + 1; ++j) {
+      if (j == i) matrix_t[i * (n + 1) + j] = 1;
+      if (j == n) {
+        if (poly_array[i] == 0) {
+          matrix_t[i * (n + 1) + j] = 0;
+        } else {
+          matrix_t[i * (n + 1) + j] = -pow(poly_array[i], n);
+        }
+      }
+    }
+  }
+}
+
+int B(const float *poly_array, float *matrix_b, int in_unit) {
+  memset(matrix_b, 0, in_unit * in_unit * sizeof(float));
+  int n = in_unit - 1;
+  if ((n * n) > MAX_LEN || (n * in_unit) > MAX_LEN) {
+    return NNACL_ERR;
+  }
+  float matrix_l[MAX_LEN];   // n * n
+  float matrix_lt[MAX_LEN];  // n * n
+  float matrix_t[MAX_LEN];   // n * in_unit
+
+  T(poly_array, matrix_t, n);
+  if (LT(poly_array, matrix_lt, n) != NNACL_OK) {
+    return NNACL_ERR;
+  }
+  MatrixTranspose(matrix_lt, matrix_l, n, n);
+  MatrixMultiply(matrix_l, matrix_t, matrix_b, n, n, in_unit);
+  matrix_b[in_unit * in_unit - 1] = 1;
+  return NNACL_OK;
+}
+
+#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
+void MatrixMultiplyWinograd(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
+                            int in_channel, int c4_channel) {
+  int cnt = 0;
+  for (int i = 0; i < m; ++i) {
+    for (int j = 0; j < n; ++j) {
+      for (int y = 0; y < in_channel; ++y) {
+        float tmp = 0;
+        for (int z = 0; z < k; ++z) {
+          tmp += matrix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
+        }
+        matrix_c[cnt++] = tmp;
+      }
+      cnt +=
c4_channel / 4 - in_channel; + } + } +} +#endif + +void GenerateIntervalArray(float *array, float interval, int degree) { + array[0] = 0; + for (int i = 1; i < degree; ++i) { + int coefficient = pow(-1, i - 1); + array[i] = array[i - 1] + interval * i * coefficient; + } +} + +void MatrixTranspose(const float *matrix, float *trans_matrix, int row, int col) { + for (int i = 0; i < col; ++i) { + for (int j = 0; j < row; ++j) { + trans_matrix[i * row + j] = matrix[j * col + i]; + } + } +} + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} + +int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g, + float *matrix_gt, float coefficient, int out_unit, int filter_size) { + int in_unit = out_unit + filter_size - 1; + int degree = in_unit - 1; + if (degree > MAX_LEN || (in_unit * in_unit) > MAX_LEN || (in_unit * filter_size) > MAX_LEN) { + return NNACL_ERR; + } + float polynomial_m[MAX_LEN]; // degree + float diagonal_matrix[MAX_LEN]; // input_unit * input_unit + float inverse_diagonal_matrix[MAX_LEN]; // input_unit * input_unit + + // get diagonal matrix + float interval[MAX_LEN]; // degree + GenerateIntervalArray(interval, coefficient, degree); + Polynomial(interval, polynomial_m, degree); + DiagonalPlusMatrix(polynomial_m, diagonal_matrix, degree); + if (diagonal_matrix[0] < 0) { + for (int i = 0; i < in_unit; ++i) { + if (diagonal_matrix[i] != 0) diagonal_matrix[i] *= -1; + } + } + + // inverse diagonal matrix + for (int j = 0; j < in_unit * in_unit; ++j) { + if (diagonal_matrix[j] != 0) { + inverse_diagonal_matrix[j] = 1.0 / diagonal_matrix[j]; + } else { + inverse_diagonal_matrix[j] = 0; + } + } + + // get matrix A && AT + ResidueMatrix(interval, matrix_a, in_unit, out_unit); + MatrixTranspose(matrix_a, matrix_at, in_unit, out_unit); + + // get matrix B + int ret = B(interval, matrix_bt, in_unit); + if (ret != NNACL_OK) { + return ret; + } + MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit); + MatrixMultiply(diagonal_matrix, matrix_b, matrix_bt, in_unit, in_unit, in_unit); + MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit); + + // get matrix G && GT + float tmp_g[MAX_LEN]; // in_unit * filter_size + ResidueMatrix(interval, matrix_g, in_unit, filter_size); + MatrixTranspose(matrix_g, tmp_g, in_unit, filter_size); + MatrixMultiply(tmp_g, inverse_diagonal_matrix, matrix_gt, filter_size, in_unit, in_unit); + MatrixTranspose(matrix_gt, matrix_g, filter_size, in_unit); + return NNACL_OK; +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, + const float *bias, int m, int k, int n) { + int count = 0; + MS_FLOAT32X4 bias_ptr = MS_MOVQ_F32(0); + if (bias != NULL) { + bias_ptr = MS_LDQ_F32(bias); + } + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + MS_FLOAT32X4 res = MS_MOVQ_F32(0); + for (int i = 0; i < k; i++) { + res = MS_MLAQ_F32(res, matrix_a[h_offset + i], matrix_b[w + i * n]); + } + matrix_c[count] = MS_ADDQ_F32(res, bias_ptr); + count++; + } + } +} +#endif + +int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float 
*matrix_gt, + int oc_block, int input_unit, int kernel_unit, int channel, int batch, bool pack) { + if (oc_block == 0) { + return NNACL_PARAM_INVALID; + } + // original weight format : ohwi + int oc_block_num = UP_DIV(batch, oc_block); + int block_stride = channel * oc_block; + int block_num_stride = block_stride * oc_block_num; + + // trans_filter = G*g*GT (g represents weight_data) + // separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd + float *tmp_data = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float))); + if (tmp_data == NULL) { + return NNACL_ERR; + } + float *trans_out_data = (float *)(malloc(channel * input_unit * input_unit * sizeof(float))); + if (trans_out_data == NULL) { + free(tmp_data); + return NNACL_ERR; + } + +#ifndef ENABLE_ARM + float *tmp_data1 = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float))); + if (tmp_data1 == NULL) { + free(tmp_data); + free(trans_out_data); + return NNACL_ERR; + } + float *trans_out_data1 = (float *)(malloc(channel * input_unit * input_unit * sizeof(float))); + if (trans_out_data1 == NULL) { + free(tmp_data); + free(tmp_data1); + free(trans_out_data); + return NNACL_ERR; + } +#endif + + int input_oz_offset = kernel_unit * kernel_unit * channel; + for (int i = 0; i < batch; i++) { + int out_c_block = i / oc_block; + int out_c_res = i % oc_block; + int output_oz_offset = out_c_block * block_stride + out_c_res; + +#ifndef ENABLE_ARM + // tmp_data = g * GT + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit, + channel, channel * 4); + // tmp_data1 = (tmp_data)T + PackHWCToWHC(tmp_data, tmp_data1, kernel_unit, input_unit, channel); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit, kernel_unit, input_unit, channel, + channel * 4); + // trans_out_data = (trans_out_data1)T + PackHWCToWHC(trans_out_data1, trans_out_data, input_unit, input_unit, channel); +#else + // tmp = (g * GT)T + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit, + channel, channel * 4); + // trans = (tmp * GT)T + MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit, kernel_unit, input_unit, channel, + channel * 4); +#endif + if (pack) { + int in_offset = 0; + for (int j = 0; j < input_unit; ++j) { + for (int k = 0; k < input_unit; ++k) { + for (int c = 0; c < channel; ++c) { + *(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; + } + in_offset += channel; + output_oz_offset += block_num_stride; + } + } + } else { + memcpy(winograd_data + i * channel * input_unit * input_unit, trans_out_data, + channel * input_unit * input_unit * sizeof(float)); + } + } +#ifndef ENABLE_ARM + free(tmp_data1); + free(trans_out_data1); +#endif + free(tmp_data); + free(trans_out_data); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.h new file mode 100644 index 00000000..44f5bd00 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/minimal_filtering_generator.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_MINIMAL_FILTERING_GENERATOR_H_
+#define NNACL_MINIMAL_FILTERING_GENERATOR_H_
+
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Polynomial(const float *interval, float *m, int degree);
+
+void DiagonalPlusMatrix(const float *matrix, float *diagonal_matrix, int degree);
+
+void ResidueMatrix(const float *interval, float *b, int row, int col);
+
+int LT(const float *poly_array, float *matrix_lt, int n);
+
+void T(const float *poly_array, float *matrix_t, int n);
+
+int B(const float *poly_array, float *matrix_b, int in_unit);
+
+void GenerateIntervalArray(float *array, float interval, int degree);
+
+void MatrixTranspose(const float *matrix, float *trans_matrix, int row, int col);
+
+void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n);
+
+int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g,
+                   float *matrix_gt, float coefficient, int out_unit, int filter_size);
+void MatrixMultiplyWinograd(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
+                            int in_channel, int c4_channel);
+
+int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt,
+                            int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_MINIMAL_FILTERING_GENERATOR_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.c
new file mode 100644
index 00000000..d00c33da
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.c
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/base/scatter_nd_binary.h"
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/scatter_nd_binary_simd.h"
+
+int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
+                 int task_id) {
+  if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  if (param->op_parameter.thread_num_ == 0) {
+    return NNACL_ERR;
+  }
+  int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
+  int begin = unit_per_thread * task_id;
+  int end = MSMIN(begin + unit_per_thread, param->num_unit);
+  if (type == 0) {
+    float *update_fp32 = (float *)update;
+    float *output_fp32 = (float *)output;
+    for (int i = begin; i < end; i++) {
+      const float *update_data = update_fp32 + i * param->unit_size;
+      float *output_data = output_fp32 + output_unit_offsets[i];
+      int j = 0;
+
+      SIMD_RUN_NO_SCALAR(ScatterNDAddFp32, j, update_data, param->unit_size, output_data);
+      for (; j < param->unit_size; j++) {
+        output_data[j] += update_data[j];
+      }
+    }
+  } else {
+    int *update_int32 = (int *)update;
+    int *output_int32 = (int *)output;
+    for (int i = begin; i < end; i++) {
+      const int *update_data = update_int32 + i * param->unit_size;
+      int *output_data = output_int32 + output_unit_offsets[i];
+      int j = 0;
+
+      SIMD_RUN_NO_SCALAR(ScatterNDAddInt32, j, update_data, param->unit_size, output_data);
+      for (; j < param->unit_size; j++) {
+        output_data[j] += update_data[j];
+      }
+    }
+  }
+  return NNACL_OK;
+}
+
+int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, const ScatterNDParameter *param,
+                    int task_id) {
+  if (param->op_parameter.thread_num_ == 0) {
+    return NNACL_ERR;
+  }
+  int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
+  int begin = unit_per_thread * task_id;
+  int end = MSMIN(begin + unit_per_thread, param->num_unit);
+
+  int data_type_len = param->data_type_len;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(param->unit_size, data_type_len, NNACL_ERR);
+
+  for (int i = begin; i < end; i++) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_unit_offsets[i], data_type_len, NNACL_ERR);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(i, param->unit_size * data_type_len, NNACL_ERR);
+    (void)memcpy((int8_t *)output + output_unit_offsets[i] * data_type_len,
+                 (int8_t *)update + i * param->unit_size * data_type_len, param->unit_size * data_type_len);
+  }
+  return NNACL_OK;
+}
+
+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
+                 int task_id) {
+  if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  if (param->op_parameter.thread_num_ == 0) {
+    return NNACL_ERR;
+  }
+  int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
+  int begin = unit_per_thread * task_id;
+  int end = MSMIN(begin + unit_per_thread, param->num_unit);
+  if (type == 0) {
+    float *update_fp32 = (float *)update;
+    float *output_fp32 = (float *)output;
+    for (int i = begin; i < end; i++) {
+      const float *update_data = update_fp32 + i * param->unit_size;
+      float *output_data = output_fp32 + output_unit_offsets[i];
+      int j = 0;
+
+      SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, param->unit_size, output_data);
+      for (; j < param->unit_size; j++) {
+        output_data[j] = fmaxf(update_data[j], output_data[j]);
+      }
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.h
new file mode 100644
index 00000000..098d87d7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BASE_SCATTER_ND_BINARY_H_
+#define NNACL_BASE_SCATTER_ND_BINARY_H_
+
+#include "nnacl_c/scatter_nd_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, const ScatterNDParameter *param,
+                    int task_id);
+
+int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
+                 int task_id);
+
+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
+                 int task_id);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_BASE_SCATTER_ND_BINARY_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary_simd.h.in
new file mode 100644
index 00000000..25258ede
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/scatter_nd_binary_simd.h.in
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_
+#define NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int ScatterNDAddFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
+  }
+  return index;
+}
+
+static inline int ScatterNDAddInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
+  }
+  return index;
+}
+
+static inline int ScatterNDMaxFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
+  }
+  return index;
+}
+
+static inline int ScatterNDMaxInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/sequence_unstack_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/sequence_unstack_base.h
new file mode 100644
index 00000000..988ce880
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/sequence_unstack_base.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_
+#define MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/sequence_unstack_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void SequenceUnstack(const void *input, void **output, const SequenceUnstackParameter *para, int data_size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.c
new file mode 100644
index 00000000..c0bca174
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.c
@@ -0,0 +1,173 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/base/slice_base.h"
+#include <string.h>
+
+void InitSliceStruct(SliceStruct *slice, TensorC *in_tensor, TensorC *begin_tensor, TensorC *size_tensor) {
+  slice->param_length_ = in_tensor->shape_size_;
+
+  int32_t *begin = (int32_t *)begin_tensor->data_;
+  int32_t *size = (int32_t *)size_tensor->data_;
+
+  for (int i = 0; i < slice->param_length_; ++i) {
+    slice->shape_[i] = in_tensor->shape_[i];
+    slice->begin_[i] = begin[i];
+    slice->size_[i] = size[i] < 0 ? slice->shape_[i] - slice->begin_[i] : size[i];
+    slice->end_[i] = slice->begin_[i] + slice->size_[i];
+  }
+  return;
+}
+
+void PadSliceParameterTo8D(SliceStruct *param) {
+  int32_t begin[DIMENSION_8D];
+  int32_t end[DIMENSION_8D];
+  int32_t slice_size[DIMENSION_8D];
+  int32_t data_shape[DIMENSION_8D];
+  for (int32_t i = 0; i < param->param_length_; ++i) {
+    begin[i] = param->begin_[i];
+    end[i] = param->end_[i];
+    slice_size[i] = param->size_[i] < 0 ?
param->shape_[i] - begin[i] : param->size_[i]; + data_shape[i] = param->shape_[i]; + } + int32_t real_index = param->param_length_ - 1; + for (int32_t i = DIMENSION_8D - 1; i >= 0; --i) { + if (real_index >= 0) { + param->begin_[i] = begin[real_index]; + param->end_[i] = end[real_index]; + param->size_[i] = slice_size[real_index]; + param->shape_[i] = data_shape[real_index--]; + } else { + param->begin_[i] = 0; + param->end_[i] = 1; + param->size_[i] = 1; + param->shape_[i] = 1; + } + } + param->param_length_ = DIMENSION_8D; +} + +void DoSlice(const void *input, void *output, const SliceStruct *param, int thread_id, int thread_num, int data_size) { + int8_t *int8_in = (int8_t *)input; + int8_t *int8_out = (int8_t *)output; + + int out_stride[8]; + out_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + out_stride[i] = out_stride[i + 1] * param->size_[i + 1]; + } + int count_per_thread = UP_DIV(param->size_[5], thread_num); + int thread_begin = thread_id * count_per_thread; + int thread_end = MSMIN(param->size_[5], thread_begin + count_per_thread); + int copy_size = param->size_[7] * data_size; + int in_stride[8]; + in_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + in_stride[i] = param->shape_[i + 1] * in_stride[i + 1]; + } + + for (int ii = 0; ii < param->size_[0]; ++ii) { + int out_offset0 = ii * out_stride[0]; + int in_offset0 = (ii + param->begin_[0]) * in_stride[0] + param->begin_[7]; + for (int jj = 0; jj < param->size_[1]; ++jj) { + int out_offset1 = jj * out_stride[1] + out_offset0; + int in_offset1 = (jj + param->begin_[1]) * in_stride[1] + in_offset0; + for (int kk = 0; kk < param->size_[2]; ++kk) { + int out_offset2 = kk * out_stride[2] + out_offset1; + int in_offset2 = (kk + param->begin_[2]) * in_stride[2] + in_offset1; + for (int ll = 0; ll < param->size_[3]; ++ll) { + int out_offset3 = ll * out_stride[3] + out_offset2; + int in_offset3 = (ll + param->begin_[3]) * in_stride[3] + in_offset2; + for (int i = 0; i < param->size_[4]; ++i) { + int out_offset4 = i * out_stride[4] + out_offset3; + int in_offset4 = (i + param->begin_[4]) * in_stride[4] + in_offset3; + for (int j = thread_begin; j < thread_end; ++j) { + int out_offset5 = j * out_stride[5] + out_offset4; + int in_offset5 = (j + param->begin_[5]) * in_stride[5] + in_offset4; + for (int k = 0; k < param->size_[6]; ++k) { + int out_offset6 = k * out_stride[6] + out_offset5; + int in_offset6 = (k + param->begin_[6]) * in_stride[6] + in_offset5; + memcpy(int8_out + out_offset6 * data_size, int8_in + in_offset6 * data_size, copy_size); + } + } + } + } + } + } + } +} + +static bool WhetherCopyByAxis(const int32_t *begin, const int32_t *end, const int32_t *shape, int dim) { + for (int i = dim + 1; i < DIMENSION_8D; ++i) { + if (begin[i] != 0 || end[i] != shape[i]) return false; + } + return true; +} + +void DoSliceNoParallel(const void *input, void *output, const SliceStruct *param, int data_size) { + int8_t *int8_in = (int8_t *)input; + int8_t *int8_out = (int8_t *)output; + + int copy_size = param->size_[7] * data_size; + int in_stride[8]; + in_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + in_stride[i] = param->shape_[i + 1] * in_stride[i + 1]; + } + bool axis_copy_flag[DIMENSION_8D] = {false}; + for (int i = 0; i < DIMENSION_8D; ++i) { + axis_copy_flag[i] = WhetherCopyByAxis(param->begin_, param->end_, param->shape_, i); + } + int out_offset = 0; + for (int32_t dim0 = param->begin_[0]; dim0 < param->end_[0]; ++dim0) { + int in_offset0 = dim0 * in_stride[0] + param->begin_[7]; +#define FAST_COPY_IF_NEED(rank) \ + if 
(axis_copy_flag[rank]) {                                                             \
+      int left_block_num = param->end_[rank] - dim##rank;                              \
+      memcpy(int8_out + out_offset * data_size, int8_in + in_offset##rank * data_size, \
+             in_stride[rank] * left_block_num * data_size);                            \
+      out_offset += in_stride[rank] * left_block_num;                                  \
+      dim##rank += left_block_num;                                                     \
+      continue;                                                                        \
+    }
+    FAST_COPY_IF_NEED(0);
+    for (int dim1 = param->begin_[1]; dim1 < param->end_[1]; ++dim1) {
+      int in_offset1 = dim1 * in_stride[1] + in_offset0;
+      FAST_COPY_IF_NEED(1);
+      for (int32_t dim2 = param->begin_[2]; dim2 < param->end_[2]; ++dim2) {
+        int in_offset2 = in_offset1 + dim2 * in_stride[2];
+        FAST_COPY_IF_NEED(2);
+        for (int32_t dim3 = param->begin_[3]; dim3 < param->end_[3]; ++dim3) {
+          int in_offset3 = in_offset2 + dim3 * in_stride[3];
+          FAST_COPY_IF_NEED(3);
+          for (int32_t dim4 = param->begin_[4]; dim4 < param->end_[4]; ++dim4) {
+            int in_offset4 = in_offset3 + dim4 * in_stride[4];
+            FAST_COPY_IF_NEED(4);
+            for (int32_t dim5 = param->begin_[5]; dim5 < param->end_[5]; ++dim5) {
+              int in_offset5 = in_offset4 + dim5 * in_stride[5];
+              FAST_COPY_IF_NEED(5);
+#undef FAST_COPY_IF_NEED
+              for (int32_t dim6 = param->begin_[6]; dim6 < param->end_[6]; ++dim6) {
+                int in_offset6 = in_offset5 + dim6 * in_stride[6];
+                memcpy(int8_out + out_offset * data_size, int8_in + in_offset6 * data_size, copy_size);
+                out_offset += param->size_[7];
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.h
new file mode 100644
index 00000000..8656cb82
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/slice_base.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_BASE_SLICE_BASE_H_
+#define NNACL_BASE_SLICE_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/kernel/slice.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void InitSliceStruct(SliceStruct *slice, TensorC *in_tensor, TensorC *begin_tensor, TensorC *size_tensor);
+void PadSliceParameterTo8D(SliceStruct *param);
+
+void DoSlice(const void *input, void *output, const SliceStruct *param, int thread_id, int thread_num, int data_size);
+void DoSliceNoParallel(const void *input, void *output, const SliceStruct *param, int data_size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_BASE_SLICE_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.c
new file mode 100644
index 00000000..9633a35d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.c
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/space_to_depth_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +int SpaceToDepthForNHWC(const void *input, void *output, const int *in_shape, const int *out_shape, int shape_size, + SpaceToDepthParameter *param, int task_id) { + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int output_h = out_shape[kNHWC_H]; + int unit_per_thread = UP_DIV(output_h, param->op_parameter_.thread_num_); + int h_start = unit_per_thread * task_id; + int h_end = MSMIN(h_start + unit_per_thread, output_h); + + int block_size = param->block_size_; + int in_strides[C4NUM]; + int out_strides[C4NUM]; + ComputeStrides(in_shape, in_strides, shape_size); + ComputeStrides(out_shape, out_strides, shape_size); + for (int i = 0; i < out_shape[0]; ++i) { + int64_t in_offset_n = i * in_strides[0]; + int64_t out_offset_n = i * out_strides[0]; + for (int j = h_start; j < h_end; ++j) { + int64_t in_offset_h = in_offset_n + j * block_size * in_strides[1]; + int64_t out_offset_h = out_offset_n + j * out_strides[1]; + for (int k = 0; k < out_shape[2]; ++k) { + int64_t in_offset_w = in_offset_h + k * block_size * in_strides[2]; + int64_t out_offset_w = out_offset_h + k * out_strides[2]; + for (int l = 0; l < block_size; ++l) { + memcpy((int8_t *)output + (out_offset_w + l * block_size * in_strides[DIMENSION_2D]) * param->date_type_len, + (const int8_t *)input + (in_offset_w + l * in_strides[DIMENSION_1D]) * param->date_type_len, + block_size * in_strides[DIMENSION_2D] * param->date_type_len); + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h new file mode 100644 index 00000000..916f3c18 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h
new file mode 100644
index 00000000..916f3c18
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/space_to_depth_base.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_BASE_SPACE_TO_DEPTH_BASE_H_
+#define NNACL_BASE_SPACE_TO_DEPTH_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/space_to_depth_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SpaceToDepthForNHWC(const void *input, void *output, const int *in_shape, const int *out_shape, int shape_size,
+                        SpaceToDepthParameter *param, int task_id);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_BASE_SPACE_TO_DEPTH_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.c
new file mode 100644
index 00000000..f48922c3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.c
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/base/split_base.h"
+#include "nnacl_c/split_parameter.h"
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+int DoSplit(const void *in_data, void **out_data, const int *input_shape, int offset, int num_unit,
+            const SplitParameter *split_param, int data_size) {
+  const int8_t *int8_in = (int8_t *)in_data;
+
+  const int num_split = split_param->num_split_;
+  const int *split_sizes = split_param->split_sizes_;
+  const int *strides = split_param->strides_;
+  const int split_dim = split_param->split_dim_;
+  int in_stride = strides[split_dim];
+
+  int in_stride_bytes = in_stride * data_size;
+
+  int split_which;
+  int split_times;
+  int stride_per_split = in_stride * input_shape[split_dim];
+
+  split_which = offset % num_split;
+  split_times = offset / num_split;
+  const int8_t *src = int8_in + split_times * stride_per_split * data_size;
+
+  for (int i = 0; i < split_which; i++) {
+    src += split_sizes[i] * in_stride * data_size;
+  }
+
+  for (int i = offset; i < offset + num_unit; i++) {
+    split_which = i % num_split;
+    split_times = i / num_split;
+    int split_size = split_sizes[split_which];
+    int8_t *int8_out = (int8_t *)out_data[split_which];
+    int8_t *dst = int8_out + split_times * in_stride * split_size * data_size;
+    (void)memcpy(dst, src, split_size * in_stride_bytes);
+    src += split_size * in_stride * data_size;
+  }
+
+  return NNACL_OK;
+}
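As a quick orientation, a minimal caller sketch for DoSplit. The shapes and split sizes are invented for illustration; split_sizes_ is assumed to be a plain int pointer and strides_ a fixed-size array, per split_parameter.h in this tree:

#include "nnacl_c/base/split_base.h"

/* Splits a 2x4 int32 matrix along axis 1 into a 2x1 and a 2x3 part.
 * strides_[d] is the element stride of input dimension d (here {4, 1}). */
static int ExampleSplit(void) {
  int32_t in[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  int32_t part0[2], part1[6];
  void *outs[2] = {part0, part1};
  int input_shape[2] = {2, 4};
  int split_sizes[2] = {1, 3};
  SplitParameter param = {0};
  param.num_split_ = 2;
  param.split_dim_ = 1;
  param.split_sizes_ = split_sizes;
  param.strides_[0] = 4;
  param.strides_[1] = 1;
  /* 4 units in total: 2 outer rows times 2 output slices, done in one call */
  int ret = DoSplit(in, outs, input_shape, 0, 4, &param, sizeof(int32_t));
  /* part0 = {1, 5}; part1 = {2, 3, 4, 6, 7, 8} */
  return ret;
}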
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.h
new file mode 100644
index 00000000..71a5af46
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_base.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BASE_SPLIT_BASE_H_
+#define NNACL_BASE_SPLIT_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/split_parameter.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+int DoSplit(const void *in_data, void **out_data, const int *input_shape, int offset, int num_unit,
+            const SplitParameter *split_param, int data_size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_BASE_SPLIT_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c
new file mode 100644
index 00000000..dfca9170
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.c
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/base/split_with_over_lap_base.h"
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+int DoSplitWithOverlapParallel(const char *in_data, char **out_data, int slice_idx,
+                               const SplitWithOverlapParameter *param, const int *start_indices,
+                               const int *end_indices) {
+  int start_index = start_indices[slice_idx];
+  int end_index = end_indices[slice_idx];
+
+  int input_stride = param->split_dim_size_ * param->inner_stride_ * param->element_bytes_;
+  int out_stride = (end_index - start_index) * param->inner_stride_ * param->element_bytes_;
+
+  const char *src_ptr = in_data + start_index * param->inner_stride_ * param->element_bytes_;
+  char *dst_ptr = out_data[slice_idx];
+
+  for (int i = 0; i < param->outer_total_dim_; i++) {
+    (void)memcpy(dst_ptr + i * out_stride, src_ptr, out_stride);
+    src_ptr += input_stride;
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.h
new file mode 100644
index 00000000..3b5db1c1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/split_with_over_lap_base.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
+#define NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/split_parameter.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+int DoSplitWithOverlapParallel(const char *in_data, char **out_data, int slice_idx,
+                               const SplitWithOverlapParameter *param, const int *start_indices,
+                               const int *end_indices);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_SPLIT_WITH_OVER_LAP_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.c
new file mode 100644
index 00000000..64454dd5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.c
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/base/stack_base.h"
+
+void Stack(void **inputs, void *output, size_t input_num, size_t copy_size, int outer_start, int outer_end) {
+  size_t out_offset = 0;
+  for (size_t i = outer_start; i < outer_end; ++i) {
+    for (size_t j = 0; j < input_num; ++j) {
+      memcpy((char *)output + out_offset, (char *)inputs[j] + i * copy_size, copy_size);
+      out_offset += copy_size;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.h
new file mode 100644
index 00000000..b54a9e75
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/stack_base.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_BASE_STACK_BASE_H_
+#define NNACL_BASE_STACK_BASE_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/stack_parameter.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+void Stack(void **inputs, void *output, size_t input_num, size_t copy_size, int outer_start, int outer_end);
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_BASE_STACK_BASE_H_
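For reference, a minimal sketch of stacking two flat tensors along a new leading axis with Stack. The data is invented for illustration; for axis 0 the outer range is a single iteration and copy_size is the whole per-input byte count:

#include "nnacl_c/base/stack_base.h"

/* Stacks two length-4 float tensors into a 2x4 output. */
static void ExampleStack(void) {
  float a[4] = {0, 1, 2, 3};
  float b[4] = {4, 5, 6, 7};
  float out[8];
  void *inputs[2] = {a, b};
  /* outer range [0, 1): each input is copied exactly once, interleaved */
  Stack(inputs, out, 2, 4 * sizeof(float), 0, 1);
  /* out = {0, 1, 2, 3, 4, 5, 6, 7} */
}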
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.c
new file mode 100644
index 00000000..e80328c9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.c
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/base/tile_base.h"
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+void DoCopyData(const uint8_t *input_data, uint8_t *output_data, size_t size, size_t data_size, size_t multiple) {
+  uint8_t *out_data = output_data;
+  for (size_t i = 0; i < multiple; ++i) {
+    (void)memcpy(out_data, input_data, size * sizeof(uint8_t) * data_size);
+    out_data += size * data_size;
+  }
+}
+
+int DoTileOneDimension(uint8_t *input_data, uint8_t *output_data, size_t dim, const TileStruct *tile) {
+  int src_dim_size = tile->in_shape_[dim];
+  if (dim == tile->in_dim_ - 1) {
+    DoCopyData(input_data, output_data, src_dim_size, tile->data_size_, tile->multiples_[dim]);
+    return NNACL_OK;
+  }
+  for (int i = 0; i < src_dim_size; ++i) {
+    for (int j = 0; j < tile->multiples_[dim]; ++j) {
+      int in_pos = tile->in_strides_[dim] * i;
+      int out_pos = tile->out_strides_[dim] * (i + j * src_dim_size);
+      DoTileOneDimension(input_data + in_pos * tile->data_size_, output_data + out_pos * tile->data_size_, dim + 1,
+                         tile);
+    }
+  }
+  return NNACL_OK;
+}
+
+void Tile(void *input_data, void *output_data, const TileStruct *tile) {
+  DoTileOneDimension((uint8_t *)input_data, (uint8_t *)output_data, 0, tile);
+}
+
+void TileSimple(void *input_data, void *output_data, size_t begin, size_t end, const TileStruct *tile) {
+  uint8_t *out_data = output_data;
+  uint8_t *in_data = input_data;
+  size_t dst_one_row_size = tile->fast_stride_ * tile->fast_multiple_ * tile->data_size_;
+  for (size_t i = begin; i < end; ++i) {
+    uint8_t *src = in_data + i * tile->fast_stride_ * tile->data_size_;
+    uint8_t *dst = out_data + i * tile->fast_stride_ * tile->fast_multiple_ * tile->data_size_;
+    size_t offset = tile->fast_stride_ * tile->data_size_;
+    (void)memcpy(dst, src, offset);
+    // copy size doubles each time
+    while (2 * offset <= dst_one_row_size) {
+      (void)memcpy(dst + offset, dst, offset);
+      offset *= 2;
+    }
+    if (2 * offset > dst_one_row_size) {
+      (void)memcpy(dst + offset, dst, dst_one_row_size - offset);
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.h
new file mode 100644
index 00000000..cc84e8d8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/tile_base.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef NNACL_BASE_TILE_BASE_H_ +#define NNACL_BASE_TILE_BASE_H_ + +#include "nnacl_c/kernel/tile.h" +#include "nnacl_c/tile_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Tile(void *input_data, void *output_data, const TileStruct *tile); +void TileSimple(void *input_data, void *output_data, size_t begin, size_t end, const TileStruct *tile); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_TILE_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.c new file mode 100644 index 00000000..55a53d58 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.c @@ -0,0 +1,274 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/base/transpose_base.h" +#include "nnacl_c/errorcode.h" + +#define TRANSPOSE_TWO_DIMS(TYPE, NAME) \ + void TransposeDim2##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * output1; \ + int stride0_i = i * 1 * stride0; \ + for (int j = 0; j < output1; ++j) { \ + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; \ + } \ + } \ + } + +#define TRANSPOSE_THREE_DIMS(TYPE, NAME) \ + void TransposeDim3##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; \ + } \ + } \ + } \ + } + +#define TRANSPOSE_FOUR_DIMS(TYPE, NAME) \ + void TransposeDim4##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ 
+ const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = \ + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_FIVE_DIMS(TYPE, NAME) \ + void TransposeDim5##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int stride4 = strides[perm[4]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int out_stride3 = out_strides[3]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + const int output4 = output_shape[4]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + int out_stride3_m = m * out_stride3; \ + int stride3_m = m * stride3; \ + for (int n = 0; n < output4; ++n) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = \ + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; \ + } \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_SIX_DIMS(TYPE, NAME) \ + void TransposeDim6##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int stride4 = strides[perm[4]]; \ + const int stride5 = strides[perm[5]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int out_stride3 = out_strides[3]; \ + const int out_stride4 = out_strides[4]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + const int output4 = output_shape[4]; \ + const int output5 = output_shape[5]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + int out_stride3_m = m * out_stride3; \ + int stride3_m = m * stride3; \ + for (int n = 0; n < output4; ++n) { \ + int out_stride4_n = n * out_stride4; 
\ + int stride4_n = n * stride4; \ + for (int g = 0; g < output5; ++g) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] = \ + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5]; \ + } \ + } \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_DIMS(TYPE, NAME) \ + void TransposeDims##NAME(const TYPE *in_data, TYPE *out_data, const int *output_shape, \ + const TransposeParameter *transpose_param, int task_id, int thread_num) { \ + NNACL_CHECK_NULL_RETURN_VOID(in_data); \ + NNACL_CHECK_NULL_RETURN_VOID(out_data); \ + NNACL_CHECK_NULL_RETURN_VOID(output_shape); \ + NNACL_CHECK_NULL_RETURN_VOID(transpose_param); \ + NNACL_CHECK_ZERO_RETURN(thread_num); \ + const int *perm = transpose_param->perm_; \ + const int *strides = transpose_param->strides_; \ + const int *out_strides = transpose_param->out_strides_; \ + int num_axes = transpose_param->num_axes_; \ + size_t data_size = (*out_strides) * output_shape[0]; \ + size_t offset_size = UP_DIV(data_size, thread_num); \ + size_t task_offset = offset_size * task_id; \ + int count = data_size - task_offset; \ + if (count <= 0) { \ + return; \ + } \ + count = MSMIN(offset_size, count); \ + for (int idx = task_offset; idx < task_offset + count; ++idx) { \ + int pos = idx; \ + int output_idx = 0; \ + int input_idx = 0; \ + for (int i = 0; i < num_axes; ++i) { \ + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); \ + int position = pos / *(out_strides + i); \ + int out_stride = i < num_axes - 1 ? out_strides[i] : 1; \ + output_idx += (position * out_stride); \ + input_idx += (position * strides[perm[i]]); \ + pos -= position * (*(out_strides + i)); \ + } \ + out_data[output_idx] = in_data[input_idx]; \ + } \ + } + +#define DOTRANSPOSE(TYPE, NAME) \ + int DoTranspose##NAME(const TYPE *in_data, TYPE *out_data, const int *output_shape, \ + const TransposeParameter *transpose_param) { \ + NNACL_CHECK_NULL_RETURN_ERR(in_data); \ + NNACL_CHECK_NULL_RETURN_ERR(out_data); \ + NNACL_CHECK_NULL_RETURN_ERR(output_shape); \ + NNACL_CHECK_NULL_RETURN_ERR(transpose_param); \ + const int *perm = transpose_param->perm_; \ + const int *strides = transpose_param->strides_; \ + const int *out_strides = transpose_param->out_strides_; \ + int data_size = transpose_param->data_num_ * sizeof(TYPE); \ + int num_axes = transpose_param->num_axes_; \ + bool needTranspose = false; \ + for (int i = 1; i < num_axes; ++i) { \ + if (perm[i] - perm[i - 1] != 1) { \ + needTranspose = true; \ + break; \ + } \ + } \ + if (!needTranspose) { \ + (void)memcpy(out_data, in_data, data_size); \ + return NNACL_OK; \ + } \ + for (int i = 0; i < num_axes; ++i) { \ + if (perm[i] < 0) { \ + return NNACL_PARAM_INVALID; \ + } \ + } \ + if (num_axes == 2) { \ + TransposeDim2##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 3) { \ + TransposeDim3##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 4) { \ + TransposeDim4##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 5) { \ + TransposeDim5##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 6) { \ + TransposeDim6##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else { \ + return NNACL_ERR; \ + } \ + return NNACL_OK; \ + } + +#define TRANSPOSE_TEMPLATE(TYPE, NAME) \ + TRANSPOSE_TWO_DIMS(TYPE, NAME) \ + TRANSPOSE_THREE_DIMS(TYPE, NAME) \ + TRANSPOSE_FOUR_DIMS(TYPE, NAME) \ + 
TRANSPOSE_FIVE_DIMS(TYPE, NAME)  \
+  TRANSPOSE_SIX_DIMS(TYPE, NAME)   \
+  TRANSPOSE_DIMS(TYPE, NAME)       \
+  DOTRANSPOSE(TYPE, NAME)
+
+TRANSPOSE_TEMPLATE(uint8_t, UInt8)
+TRANSPOSE_TEMPLATE(uint16_t, UInt16)
+TRANSPOSE_TEMPLATE(uint32_t, UInt32)
+TRANSPOSE_TEMPLATE(uint64_t, UInt64)
+TRANSPOSE_TEMPLATE(int16_t, Int16)
+TRANSPOSE_TEMPLATE(int32_t, Int32)
+TRANSPOSE_TEMPLATE(int64_t, Int64)
+TRANSPOSE_TEMPLATE(double, Float64)
+TRANSPOSE_TEMPLATE(bool, Bool)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.h
new file mode 100644
index 00000000..67fa636f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/transpose_base.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BASE_TRANSPOSE_BASE_H_
+#define NNACL_BASE_TRANSPOSE_BASE_H_
+
+#include "nnacl_c/transpose_parameter.h"
+#include <stdbool.h>
+
+#ifdef __cplusplus
extern "C" {
+#endif
+
+int DoTransposeUInt8(const uint8_t *in_data, uint8_t *out_data, const int *output_shape,
+                     const TransposeParameter *transpose_param);
+int DoTransposeUInt16(const uint16_t *in_data, uint16_t *out_data, const int *output_shape,
+                      const TransposeParameter *transpose_param);
+int DoTransposeUInt32(const uint32_t *in_data, uint32_t *out_data, const int *output_shape,
+                      const TransposeParameter *transpose_param);
+int DoTransposeUInt64(const uint64_t *in_data, uint64_t *out_data, const int *output_shape,
+                      const TransposeParameter *transpose_param);
+int DoTransposeInt16(const int16_t *in_data, int16_t *out_data, const int *output_shape,
+                     const TransposeParameter *transpose_param);
+int DoTransposeInt32(const int32_t *in_data, int32_t *out_data, const int *output_shape,
+                     const TransposeParameter *transpose_param);
+int DoTransposeInt64(const int64_t *in_data, int64_t *out_data, const int *output_shape,
+                     const TransposeParameter *transpose_param);
+int DoTransposeFloat64(const double *in_data, double *out_data, const int *output_shape,
+                       const TransposeParameter *transpose_param);
+int DoTransposeBool(const bool *in_data, bool *out_data, const int *output_shape,
+                    const TransposeParameter *transpose_param);
+
+void TransposeDimsUInt8(const uint8_t *in_data, uint8_t *out_data, const int *output_shape,
+                        const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsUInt16(const uint16_t *in_data, uint16_t *out_data, const int *output_shape,
+                         const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsUInt32(const uint32_t *in_data, uint32_t *out_data, const int *output_shape,
+                         const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsUInt64(const uint64_t *in_data, uint64_t *out_data, const int *output_shape,
+                         const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsInt16(const int16_t *in_data, int16_t *out_data, const int *output_shape,
+                        const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsInt32(const int32_t *in_data, int32_t *out_data, const int *output_shape,
+                        const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsInt64(const int64_t *in_data, int64_t *out_data, const int *output_shape,
+                        const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsFloat64(const double *in_data, double *out_data, const int *output_shape,
+                          const TransposeParameter *transpose_param, int task_id, int thread_num);
+void TransposeDimsBool(const bool *in_data, bool *out_data, const int *output_shape,
+                       const TransposeParameter *transpose_param, int task_id, int thread_num);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_BASE_TRANSPOSE_BASE_H_
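A small caller sketch for the generated DoTranspose entry points, assuming perm_, strides_ and out_strides_ are fixed-size int arrays inside TransposeParameter as transpose_parameter.h in this tree defines them; the stride values mirror what the transpose kernel normally precomputes:

#include "nnacl_c/base/transpose_base.h"

/* Transposes a 2x3 int32 matrix to 3x2 with DoTransposeInt32. */
static int ExampleTranspose(void) {
  int32_t in[6] = {1, 2, 3, 4, 5, 6};
  int32_t out[6];
  int out_shape[2] = {3, 2};
  TransposeParameter param = {0};
  param.num_axes_ = 2;
  param.data_num_ = 6;
  param.perm_[0] = 1;
  param.perm_[1] = 0;
  param.strides_[0] = 3; /* input strides, in elements */
  param.strides_[1] = 1;
  param.out_strides_[0] = 2;
  param.out_strides_[1] = 1;
  return DoTransposeInt32(in, out, out_shape, &param); /* out = {1,4,2,5,3,6} */
}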
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.c
new file mode 100644
index 00000000..d962914c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.c
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/base/unsorted_segment_sum_base.h"
+#include "nnacl_c/errorcode.h"
+
+#define UNSORTEDSEGMENTSUM(type, type1)                                                                          \
+  int UnsortedSegmentSum_##type##_##type1(const type *input, int unit_num, int input_dim1, const type1 *indices, \
+                                          type *output, int output_dim0, int output_dim1) {                      \
+    NNACL_CHECK_NULL_RETURN_ERR(input);                                                                          \
+    NNACL_CHECK_NULL_RETURN_ERR(indices);                                                                        \
+    NNACL_CHECK_NULL_RETURN_ERR(output);                                                                         \
+    if (input_dim1 == 0) {                                                                                       \
+      return NNACL_ERR;                                                                                          \
+    }                                                                                                            \
+    for (int i = 0; i < unit_num; ++i) {                                                                         \
+      int j = i / input_dim1;                                                                                    \
+      int k = i % input_dim1;                                                                                    \
+                                                                                                                 \
+      type1 index = indices[j];                                                                                  \
+      if (index < 0 || index >= output_dim0) {                                                                   \
+        continue;                                                                                                \
+      }                                                                                                          \
+      type1 output_index = index * output_dim1 + k;                                                              \
+      output[output_index] += input[i];                                                                          \
+    }                                                                                                            \
+    return NNACL_OK;                                                                                             \
+  }
+
+UNSORTEDSEGMENTSUM(int, int)
+UNSORTEDSEGMENTSUM(float, int)
+UNSORTEDSEGMENTSUM(int, int64_t)
+UNSORTEDSEGMENTSUM(float, int64_t)
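A small worked example of the segment-sum contract, using the UnsortedSegmentSum dispatch macro declared in the header that follows. The data is invented; note the kernel only accumulates, so the caller must zero the output first:

#include "nnacl_c/base/unsorted_segment_sum_base.h"

/* Sums rows of a 4x2 float matrix into 3 segments keyed by per-row ids;
 * rows whose id falls outside [0, output_dim0) are skipped by the kernel. */
static int ExampleSegmentSum(void) {
  float in[8] = {1, 1, 2, 2, 3, 3, 4, 4};
  int ids[4] = {0, 2, 0, 7}; /* id 7 is out of range and ignored */
  float out[6] = {0};        /* zero-initialized by the caller */
  int ret = UnsortedSegmentSum(float, int, in, 8, 2, ids, out, 3, 2);
  /* out = {4, 4, 0, 0, 2, 2} */
  return ret;
}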
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.h
new file mode 100644
index 00000000..e05d62f8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unsorted_segment_sum_base.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_
+#define NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+#define UnsortedSegmentSum(type, type1, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \
+  UnsortedSegmentSum_##type##_##type1(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1)
+int UnsortedSegmentSum_int_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output,
+                               int output_dim0, int output_dim1);
+int UnsortedSegmentSum_float_int(const float *input, int unit_num, int input_dim1, const int *indices, float *output,
+                                 int output_dim0, int output_dim1);
+int UnsortedSegmentSum_int_int64_t(const int *input, int unit_num, int input_dim1, const int64_t *indices, int *output,
+                                   int output_dim0, int output_dim1);
+int UnsortedSegmentSum_float_int64_t(const float *input, int unit_num, int input_dim1, const int64_t *indices,
+                                     float *output, int output_dim0, int output_dim1);
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.c
new file mode 100644
index 00000000..d286de83
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.c
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/base/unstack_base.h"
+
+void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size) {
+  NNACL_CHECK_NULL_RETURN_VOID(input);
+  NNACL_CHECK_NULL_RETURN_VOID(output);
+  NNACL_CHECK_NULL_RETURN_VOID(para);
+  const int8_t *in_addr = (int8_t *)input;
+  for (int j = 0; j < para->num_; j++) {
+    int8_t *out_addr = (int8_t *)output[j];
+    int out_offset = 0;
+    for (int i = 0; i < para->pre_dims_; i++) {
+      int in_offset = i * para->axis_dim_ * para->after_dims_ + j * para->after_dims_;
+      (void)memcpy(out_addr + out_offset * data_size, in_addr + in_offset * data_size, para->after_dims_ * data_size);
+      out_offset += para->after_dims_;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.h
new file mode 100644
index 00000000..9b0856de
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/base/unstack_base.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BASE_UNSTACK_BASE_H_
+#define NNACL_BASE_UNSTACK_BASE_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/unstack_parameter.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_BASE_UNSTACK_BASE_H_
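For completeness, a minimal Unstack sketch. The shape is invented; the pre_dims_/axis_dim_/after_dims_ fields follow the usual outer/axis/inner factoring, assuming the UnstackParameter layout from unstack_parameter.h in this tree:

#include "nnacl_c/base/unstack_base.h"

/* Splits a 2x3 int32 tensor along axis 0 into two length-3 outputs. */
static void ExampleUnstack(void) {
  int32_t in[6] = {1, 2, 3, 4, 5, 6};
  int32_t out0[3], out1[3];
  void *outs[2] = {out0, out1};
  UnstackParameter param = {0};
  param.num_ = 2;      /* number of outputs == axis_dim_ */
  param.pre_dims_ = 1; /* product of dims before the axis */
  param.axis_dim_ = 2;
  param.after_dims_ = 3; /* product of dims after the axis */
  Unstack(in, outs, &param, sizeof(int32_t));
  /* out0 = {1, 2, 3}; out1 = {4, 5, 6} */
}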
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batch_to_space_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batch_to_space_parameter.h
new file mode 100644
index 00000000..ada3dd0b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batch_to_space_parameter.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_BATCH_TO_SPACE_PARAMETER_H_
+#define NNACL_BATCH_TO_SPACE_PARAMETER_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+
+#define BATCH_TO_SPACE_BLOCK_SHAPE_SIZE 2
+
+typedef struct BatchToSpaceParameter {
+  OpParameter op_parameter_;
+  int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE];
+  int32_t crops_[COMM_SHAPE_SIZE];
+} BatchToSpaceParameter;
+
+#endif  // NNACL_BATCH_TO_SPACE_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batchnorm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batchnorm_parameter.h
new file mode 100644
index 00000000..b140e953
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/batchnorm_parameter.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_BATCHNORM_PARAMETER_H_
+#define NNACL_BATCHNORM_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct BatchNormParameter {
+  OpParameter op_parameter_;
+  float epsilon_;
+  bool is_training_;
+  float momentum_;
+} BatchNormParameter;
+
+#endif  // NNACL_BATCHNORM_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/broadcast_to_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/broadcast_to_parameter.h
new file mode 100644
index 00000000..6e934aad
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/broadcast_to_parameter.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_BROADCAST_TO_PARAMETER_H_
+#define NNACL_BROADCAST_TO_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct BroadcastToParameter {
+  OpParameter op_parameter_;
+  int shape_[MAX_SHAPE_SIZE];
+  size_t shape_size_;
+} BroadcastToParameter;
+
+typedef struct BroadcastShapeInfo {
+  int input_shape_[MAX_SHAPE_SIZE];
+  int input_shape_size_;
+  int output_shape_[MAX_SHAPE_SIZE];
+  int output_shape_size_;
+} BroadcastShapeInfo;
+
+#endif  // NNACL_BROADCAST_TO_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/call_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/call_parameter.h
new file mode 100644
index 00000000..ea5e85e8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/call_parameter.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CALL_PARAMETER_H_ +#define NNACL_CALL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct CallParameter { + OpParameter op_parameter_; + bool is_tail_call; +} CallParameter; + +#endif // NNACL_CALL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/clip_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/clip_parameter.h new file mode 100644 index 00000000..8f3fbf64 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/clip_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CLIP_PARAMETER_H_ +#define NNACL_CLIP_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct ClipParameter { + OpParameter op_parameter_; + float min_val_; + float max_val_; +} ClipParameter; + +#endif // NNACL_CLIP_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.c new file mode 100644 index 00000000..cad7ea19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.c @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/common_func.h"
+
+int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) {
+  return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3;
+}
+
+int64_t OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2) {
+  return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3];
+}
+
+int Offset4d(const int *shape, const int *dims) { return Offset(shape, dims[0], dims[1], dims[2], dims[3]); }
+
+int64_t Offset6d(const int *shape, const int *dims) {
+  return ((OffsetComm(shape, dims[0], dims[1], dims[2]) + dims[3]) * shape[4] + dims[4]) * shape[5];
+}
+
+int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
+
+int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); }
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.h
new file mode 100644
index 00000000..7463a30f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/common_func.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_COMMON_FUNC_H_
+#define MINDSPORE_NNACL_COMMON_FUNC_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/nnacl_common.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+
+int8_t MinInt8(int8_t a, int8_t b);
+int8_t MaxInt8(int8_t a, int8_t b);
+int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
+int64_t OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
+int Offset4d(const int *shape, const int *dims);
+int64_t Offset6d(const int *shape, const int *dims);
+
+static inline bool isAddOverflow(int32_t x, int32_t y) {
+  int32_t sum = x + y;
+  return (x > 0 && y > 0 && sum < 0) || (x < 0 && y < 0 && sum > 0);
+}
+
+static inline bool isMulOverflow(int32_t x, int32_t y) {
+  int32_t p = x * y;
+  return (x != 0) && (p / x != y);
+}
+
+static inline int GetStride(int *strides, const int *shape, int length) {
+  if (length <= 0) {
+    return 1;
+  }
+  int stride = 1;
+  for (int i = length - 1; i >= 0; --i) {
+    strides[i] = stride;
+    stride *= shape[i];
+  }
+  return stride;
+}
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* MINDSPORE_NNACL_COMMON_FUNC_H_ */
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/concat_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/concat_parameter.h
new file mode 100644
index 00000000..c902be39
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/concat_parameter.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_CONCAT_PARAMETER_H_ +#define MINDSPORE_NNACL_CONCAT_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct ConcatParameter { + OpParameter op_parameter_; + ConcatQuantArg quant_arg_; + int axis_; +} ConcatParameter; + +#endif // MINDSPORE_NNACL_CONCAT_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/constant_of_shape_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/constant_of_shape_parameter.h new file mode 100644 index 00000000..54174e88 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/constant_of_shape_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_ +#define NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ConstantOfShapeParameter { + OpParameter op_parameter_; + union value_ { + float f32_value_; + int32_t int32_value_; + bool bool_value_; + } value_; + int data_type_; + int element_size_; +} ConstantOfShapeParameter; + +#endif // NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv3d_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv3d_parameter.h new file mode 100644 index 00000000..edaa02e6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv3d_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_CONV3D_PARAMETER_H_ +#define NNACL_CONV3D_PARAMETER_H_ + +#include +#include "nnacl_c/op_base.h" + +typedef struct Conv3DParameter { + OpParameter op_parameter_; +} Conv3DParameter; + +#endif // NNACL_CONV3D_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv_parameter.h new file mode 100644 index 00000000..a3c301a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/conv_parameter.h @@ -0,0 +1,169 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CONV_PARAMETER_H_ +#define NNACL_CONV_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct ConvParameter { + OpParameter op_parameter_; + ConvQuantArg conv_quant_arg_; + + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + int group_; + int tile_num_; /* # */ + int input_batch_; /* # */ + int input_h_; /* # */ + int input_w_; /* # */ + int input_channel_; + int output_batch_; /* # */ + int output_h_; /* # */ + int output_w_; /* # */ + int output_channel_; + int thread_num_; /* # */ + int input_unit_; /* # */ + int output_unit_; /* # */ + PadType pad_mode_; + ActType act_type_; + int channel_multiplie_; /* # */ + int output_padding_w_; /* # */ + int output_padding_h_; /* # */ + int out_format_; + + bool dynamic_shape_; +} ConvParameter; + +typedef struct ConvComputeParam { + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + + int in_n_; + int in_h_; + int in_w_; + int in_c_; + int out_n_; + int out_h_; + int out_w_; + int out_c_; + + int in_hw_; + int out_hw_; + int kernel_hw_; + int tile_num_; +} ConvComputeParam; + +typedef struct SlidingWindowParam { + int left_; + int right_; + int top_; + int bottom_; + int c_block_; + int block_channel_; + int ic_align_; + int out_step_; + int out_h_step_; + int out_c_step_; + int out_w_step_; + int out_block_step_; + int in_step_; + int in_h_step_; + int in_sh_step_; // stride H + int in_sw_step_; // stride W + int in_kh_step_; // kernel H + int in_kw_step_; // kernel W + int kernel_step_; +} SlidingWindowParam; + +typedef struct ConvDwCalcParam { + void *num_pixels_; + void *out_w_start_; + void *out_w_end_; + int first_calc_kw_; +} ConvDwCalcParam; + +#define OUPUT_UNIT 2 +#define DECONV_WINOGRAD_DEFAULT_UNIT 3 /* # */ +#define DECONV_WINOGRAD_DEFAULT_TILE 8 /* # */ +#define DECONV_WINOGRAD_BUFFER_COUNT 8 /* # */ +typedef struct DeConvWg { /* # */ + void *b_buffer_; + void *AT_; + void *BT_; + + int kh_; + int kw_; + + int k_; + int i_; + int o_; +} DeConvWg; + +typedef struct DeConvWgABuffer { /* # */ + bool buf_init_; + void *middle_buffer_; + void *dest_buffer_; +} DeConvWgABuffer; + +typedef struct DeConvComputeUnit { /* # */ + void 
*weight_; + void *tmp_buffer_; + int w_start_; + int h_start_; + int w_size_; + int h_size_; + bool use_winograd_; + DeConvWg winograd_; +} DeConvComputeUnit; + +typedef struct DeConvParam { /* # */ + DeConvComputeUnit *compute_units_; + int compute_size_; + DeConvWgABuffer a_buffer_[DECONV_WINOGRAD_BUFFER_COUNT]; + int input_plane_; + int output_plane_; + int kernel_plane_; + int ic_div_; + int oc_div_; + int ic_up_; + int oc_up_; + int thread_num_; + int in_tile_count_; + int in_tile_h_count_; + int in_tile_w_count_; + int out_tile_h_; + int out_tile_w_; +} DeConvParam; + +#endif // NNACL_CONV_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/crop_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/crop_parameter.h new file mode 100644 index 00000000..ebae6d01 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/crop_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CROP_PARAMETER_H_ +#define NNACL_CROP_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct CropParameter { + OpParameter op_parameter_; + int64_t axis_; + int offset_size_; + int64_t offset_[COMM_SHAPE_SIZE]; +} CropParameter; + +#endif // NNACL_CROP_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/cumsum_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/cumsum_parameter.h new file mode 100644 index 00000000..1e58b8dc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/cumsum_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CUMSUM_PARAMETER_H_ +#define NNACL_CUMSUM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CumSumParameter { + OpParameter op_parameter_; + bool reverse_; + bool exclusive_; + int axis_; +} CumsumParameter; + +#endif // NNACL_CUMSUM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_gru_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_gru_parameter.h new file mode 100644 index 00000000..e3970e76 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_gru_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CUSTOM_GRU_PARAMETER_H_ +#define NNACL_CUSTOM_GRU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CustomGruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int num_step; + int batch_size; + int input_size; + int hidden_size; +} CustomGruParameter; + +#endif // NNACL_CUSTOM_GRU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_is_inf_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_is_inf_parameter.h new file mode 100644 index 00000000..3a303404 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_is_inf_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ +#define MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CustomIsInfParameter { + // Primitive parameter + OpParameter op_parameter_; +} CustomIsInfParameter; + +#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_masked_fill_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_masked_fill_parameter.h new file mode 100644 index 00000000..81f68ef8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_masked_fill_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ +#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct CustomMaskedFillParameter { + // Primitive parameter + OpParameter op_parameter_; +} CustomMaskedFillParameter; + +#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_parameter.h new file mode 100644 index 00000000..0e56cdf3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/custom_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CUSTOM_PARAMETER_H_ +#define NNACL_CUSTOM_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +#define MAX_STR_LEN 64 +#define MAX_ATTR_NUM 8 + +typedef struct CustomParameter { + OpParameter op_parameter_; + char type[MAX_STR_LEN]; + char attr_name[MAX_ATTR_NUM][MAX_STR_LEN]; + char *attr_data[MAX_ATTR_NUM]; + int attr_num; +} CustomParameter; +#endif // NNACL_CUSTOM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/depth_to_space_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/depth_to_space_parameter.h new file mode 100644 index 00000000..571e22db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/depth_to_space_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#define NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct DepthToSpaceParameter { + OpParameter op_parameter_; + int32_t block_size_; + int32_t mode_; +} DepthToSpaceParameter; + +#endif // NNACL_DEPTH_TO_SPACE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/detection_post_process_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/detection_post_process_parameter.h new file mode 100644 index 00000000..74828b37 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/detection_post_process_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ +#define NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct DetectionPostProcessParameter { + OpParameter op_parameter_; + float h_scale_; + float w_scale_; + float x_scale_; + float y_scale_; + float nms_iou_threshold_; + float nms_score_threshold_; + int64_t max_detections_; + int64_t detections_per_class_; + int64_t max_classes_per_detection_; + int64_t num_classes_; + bool use_regular_nms_; + bool out_quantized_; + + float *anchors_; + + void *decoded_boxes_; + void *nms_candidate_; + void *indexes_; + void *scores_; + void *all_class_indexes_; + void *all_class_scores_; + void *single_class_indexes_; + void *selected_; +} DetectionPostProcessParameter; + +#endif // NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/dynamic_quant_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/dynamic_quant_parameter.h new file mode 100644 index 00000000..978a0365 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/dynamic_quant_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_DYNAMIC_QUANT_PARAMETER_H_ +#define NNACL_DYNAMIC_QUANT_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct DynamicQuantParameter { + OpParameter op_parameter_; + bool symmetric_; + int dst_type_; + int axis_num_; + int prefer_axes_[MAX_SHAPE_SIZE]; +} DynamicQuantParameter; + +#endif // NNACL_DYNAMIC_QUANT_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.c new file mode 100644 index 00000000..7e1a2844 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.c @@ -0,0 +1,46 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/errorcode.h" +#include + +void InitNNACLKernelErrorCode(char **nnacl_kernel_error_msg) { + nnacl_kernel_error_msg[NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID] = + "In CropAndResize, the value of box idx should match: [0, batch)."; + nnacl_kernel_error_msg[NNACL_WHERE_INPUT_NUM_INVALID] = "Invalid input number. Where op input number support 1 or 3."; + nnacl_kernel_error_msg[NNACL_WHERE_CONDITION_DATA_TYPE_ERROR] = + "Invalid input data type. Where op input data type support int32 fp32 and bool."; + nnacl_kernel_error_msg[NNACL_WHERE_CONDITION_NUM_INVALID] = + "The length of three inputs are not equal to 1 or length of output, which is unacceptable."; + nnacl_kernel_error_msg[NNACL_WHERE_INVALID_OUT_NUM] = "The element number invalid."; + nnacl_kernel_error_msg[NNACL_WHERE_NUM_MAX_INVALID] = "Inputs' length are zero"; + nnacl_kernel_error_msg[NNACL_ERR] = "NNACL common error."; +} + +char *NNACLErrorMsg(int error_code) { + static char nnacl_kernel_error_msg[NNACL_COMMON_END][MAX_MSG_LEN]; + static bool inited = false; + if (!inited) { + inited = true; + InitNNACLKernelErrorCode((char **)nnacl_kernel_error_msg); + } + + if (error_code > NNACL_OK && error_code < NNACL_COMMON_END) { + return nnacl_kernel_error_msg[error_code]; + } + + return "NNACL execute error!"; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.h new file mode 100644 index 00000000..a7c6190b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/errorcode.h @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_ERRORCODE_H_ +#define NNACL_ERRORCODE_H_ + +#include + +#define MAX_MSG_LEN 256 + +typedef enum ErrorCodeCommonEnum { + NNACL_OK = 0, + NNACL_ERR = 1, + NNACL_NULL_PTR, + NNACL_PARAM_INVALID, + NNACL_INFER_INVALID, + NNACL_INPUT_TENSOR_ERROR, + NNACL_OUTPUT_TENSOR_ERROR, + NNACL_INPUT_OUTPUT_DATA_TYPE_UNMATCH, + NNACL_FORMAT_ERROR, + NNACL_BUFFER_OVERFLOW, + NNACL_TENSOR_SIZE_INVALID, + NNACL_UNSUPPORTED_DATA_TYPE, + NNACL_UNSUPPORTED_FORMAT, + NNACL_MALLOC_BUFFER_FAILED, + NNACL_MALLOC_SIZE_INVALID, + NNACL_DISABLE_FP16, + NNACL_ADDN_SHAPE_UNMATCH, + NNACL_ACTIVATION_TYPE_INVALID, + NNACL_ARITHMETIC_DATA_TYPE_UNMATCH, + NNACL_ARITHMETIC_SHAPE_INVALID, + NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT, + NNACL_ARG_MIN_MAX_AXIS_INVALID, + NNACL_BIAS_ADD_SHAPE_NOT_MATCH, + NNACL_BIAS_ADD_SHAPE_OVERFLOW, + NNACL_BATCH_TO_SPACE_BLOCK_SHAPE_INVALID, + NNACL_BATCH_TO_SPACE_CROP_INVALID, + NNACL_BATCH_NORM_CHANNEL_SHAPE_INVALID, + NNACL_CLIP_DATA_TYPE_INVALID, + NNACL_CLIP_MINMAX_VALUE_INVALID, + NNACL_CONCAT_AXIS_INVALID, + NNACL_CONCAT_F16_INVALID_DATA_TYPE, + NNACL_CONCAT_F16_OUTPUT_DATA_INVALID, + NNACL_CONCAT_SHAPE_INVALID, + NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH, + NNACL_CONVOLUTION_INPUT_HW_OVERFLOW, + NNACL_CONVOLUTION_KERNEL_HW_OVERFLOW, + NNACL_CONVOLUTION_OUTPUT_HW_OVERFLOW, + NNACL_CONVOLUTION_WEIGHT_DATATYPE_INVALID, + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID, + NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT, + NNACL_CONVOLUTION_WEIGHT_DATA_INVALID, + NNACL_CONVOLUTION_BIAS_DATATYPE_INVALID, + NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID, + NNACL_DECONV_RESIZE_OC_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_SHAPE, + NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_REPACK, + NNACL_DECONVOLUTION_DEPTHWISE_CHANNEL_INVALID, + NNACL_DEPTH_TO_SPACE_INVALID_MODE, + NNACL_ELTWISE_INVALID_MOD, + NNACL_FILL_DATA_TYPE_INVALID, + NNACL_FUSED_BATCH_NORM_NO_CHANGE, + NNACL_FUSED_BATCH_DATA_TYPE_INVALID, + NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED, + NNACL_FUSED_BATCH_TRAIN_DATA_INVALID, + NNACL_FUSED_BATCH_TRAIN_PARAM_DATA_INVALID, + NNACL_GATHER_INDICES_DATA_TYPE_INVALID, + NNACL_GATHER_INDICES_VALUE_INVALID, + NNACL_GATHER_AXIS_INVALID, + NNACL_GATHER_INPUT_TENSOR_INVALID, + NNACL_GATHER_OUTPUT_TENSOR_INVALID, + NNACL_GATHER_D_AXIS_INVALID, + NNACL_GATHER_ND_COUNT_INVALID, + NNACL_GATHER_ND_INDICES_RANK_INVALID, + NNACL_GATHER_ND_INDICES_SHAPE_INVALID, + NNACL_GROUP_CONVOLUTION_GROUP_INVALID, + NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID, + NNACL_GROUP_CONVOLUTION_SHAPE_INVALID, + NNACL_GROUP_NORM_NUM_GROUPS_INVALID, + NNACL_GROUP_NORM_SHAPE_SIZE_INVALID, + NNACL_GROUP_NORM_FORMAT_INVALID, + NNACL_SOFTMAX_AXIS_INVALID, + NNACL_MATMUL_ACT_TYPE_INVALID, + NNACL_MATMUL_BIAS_INVALID, + NNACL_NON_ZERO_SHAPE_INVALID, + NNACL_NON_MAX_SUPPRESSION_TENSOR_SIZE_INVALID, + NNACL_NON_MAX_SUPPRESSION_PARAM_INVALID, + NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_INVALID, + NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_SCORE_UNMATCH, + NNACL_NON_MAX_SUPPRESSION_DIMENSION_SPATIAL_UNMATCH, + NNACL_NON_MAX_SUPPRESSION_UNSUPPORT_DEFINE_DATA, + NNACL_NON_MAX_SUPPRESSION_OUTPUT_SIZE_UNMATCH, + NNACL_ONE_HOT_AXIS_INVALID, + NNACL_ONE_HOT_OUTER_SIZE_INVALID, + NNACL_ONE_HOT_INNER_SIZE_INVALID, + NNACL_ONE_HOR_DEPTH_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_ON_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_OFF_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_ON_OFF_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_PAD_SHAPE_INVALID, + 
NNACL_PAD_PADDING_VALID_INVALID, + NNACL_PAD_MIRROR_PAD_SIZE_INVALID, + NNACL_POW_INVALID_DATA_TYPE, + NNACL_PRELU_SLOPE_NUM_INVALID, + NNACL_PRIOR_BOX_VALUE_INVALID, + NNACL_PRIOR_BOX_RATIO_INVALID, + NNACL_LOCAL_RESPONSE_NORM_SHAPE_INVALID, + NNACL_LOCAL_RESPONSE_NORM_DEPTH_RADIUS_INVALID, + NNACL_LAYER_NORM_OUTPUT_NUM_INVALID, + NNACL_REDUCE_AXIS_SIZE_ERROR, + NNACL_REDUCE_AXES_TENSOR_ERROR, + NNACL_REDUCE_UNSUPPORTED_DATA_TYPE, + NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID, + NNACL_REDUCE_COEFF_DATA_TYPE_INVALID, + NNACL_REVERSE_AXIS_INVALID, + NNACL_REVERSE_AXIS_VALUE_INVALID, + NNACL_REVERSE_DATA_SIZE_INVALID, + NNACL_REVERSE_NUM_AXIS_INVALID, + NNACL_SCALE_AXIS_AND_SHAPE_UNMATCH, + NNACL_SCALE_UNSUPPORT_ACT_TYPE, + NNACL_SCALE_SCALE_SHAPE_UNMATCH, + NNACL_SCALE_INPUT_NUM_INVALID, + NNACL_STACK_TENSOR_SHAPE_INVALID, + NNACL_STRIDED_SLICE_INVALID_SHAPE_SIZE, + NNACL_STRIDED_SLICE_INVALID_DATA_SIZE, + NNACL_STRIDED_SLICE_UNSUPPORTED_DATA_TYPE, + NNACL_STRIDED_SLICE_INVALID_PARALLEL_MOD, + NNACL_STRIDED_SLICE_UNSUPPORTED_MAX_8D, + NNACL_SPLICE_SHAPE_INVALID, + NNACL_TILE_INPUT_SHAPE_INVALID, + NNACL_TILE_SECOND_INPUT_NUM_INVALID, + NNACL_TILE_SECOND_INPUT_VALUE_INVALID, + NNACL_TILE_SECOND_INPUT_DATA_TYPE_INVALID, + NNACL_TILE_RESIZE_IN_RUNTIME_FAILED, + NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR, + NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID, + NNACL_TRIU_INPUT_DIMS_INVALID, + NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE, + NNACL_TRANSPOSE_INPUT_TENSOR_NUM_INVALID, + NNACL_TRANSPOSE_INPUT_TENSOR_VALUD_INVALID, + NNACL_TRANSPOSE_PERM_DIMS_INVALID, + NNACL_TRANSPOSE_PERM_TENSOR_INVALID, + NNACL_TRANSPOSE_PERM_TENSOR_VALUE_INVALID, + NNACL_TRANSPOSE_PERM_DELETE_DIMENSION_FAILED, + NNACL_WHERE_INPUT_NUM_INVALID, + NNACL_WHERE_CONDITION_DATA_TYPE_ERROR, + NNACL_WHERE_CONDITION_NUM_INVALID, + NNACL_WHERE_INVALID_OUT_NUM, + NNACL_WHERE_NUM_MAX_INVALID, + NNACL_WHERE_BROAD_CAST_FAILED, + NNACL_COMMON_END +} ErrorCodeCommonEnum; + +typedef enum ErrorCodeFp32OpEnum { + NNACL_ERRCODE_OP_FP32_START = 10000, + NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC, + NNACL_ERRCODE_REVERSE_MALLOC, + NNACL_ERRCODE_SQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_DIVISOR_ZERO, + NNACL_ERRCODE_INDEX_OUT_OF_RANGE, + NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR, + NNACL_ERRCODE_OP_FP32_END = 19999 +} ErrorCodeFp32OpEnum; + +typedef enum ErrorCodeFp16OpEnum { + NNACL_ERRCODE_OP_FP16_START = 20000, + NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR, + NNACL_ERRCODE_OP_FP16_END = 29999 +} ErrorCodeFp16OpEnum; + +typedef enum ErrorCodeUint8OpEnum { + NNACL_ERRCODE_OP_UINT8_START = 30000, + NNACL_ERRCODE_OP_UINT8_END = 39999 +} ErrorCodeUint8OpEnum; + +typedef enum ErrorCodeInt8OpEnum { + NNACL_ERRCODE_OP_INT8_START = 40000, + NNACL_ERRCODE_ADD_OVERFLOW, + NNACL_ERRCODE_MUL_OVERFLOW, + NNACL_ERRCODE_OP_INT8_END = 49999 +} ErrorCodeInt8OpEnums; + +#ifdef __cplusplus +extern "C" { +#endif +char *NNACLErrorMsg(int error_code); +#ifdef __cplusplus +} +#endif +#endif // NNACL_ERRORCODE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/exp_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/exp_parameter.h new file mode 100644 index 00000000..1c1dd12e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/exp_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_EXP_PARAMETER_H_ +#define NNACL_EXP_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ExpParameter { + OpParameter op_parameter_; + float base_; + float scale_; + float shift_; +} ExpParameter; + +#endif // NNACL_EXP_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..a3c760c1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c @@ -0,0 +1,533 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 
4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 
32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 
60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 
0(%[dst_9])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..f59b52db --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c @@ -0,0 +1,781 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + 
"vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 
4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 4(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 8(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 12(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 16(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 20(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, 
%%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 24(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 28(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), 
%%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 32(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 36(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 40(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, 
%%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 44(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 48(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + 
"vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 52(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 56(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 60(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + 
"je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..c6b24cba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c @@ -0,0 +1,573 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, 
%%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 4 + "vmovups 
256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 
28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 
2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_9])\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..9452a5d3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c @@ -0,0 +1,844 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, 
%%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 
4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 6 + "vmovups 768(%[weight]), 
%%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), 
%%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps 
%%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + 
"vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" 
+ "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] 
"r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..14fa99e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c @@ -0,0 +1,614 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 0(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps 
%%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], 
%[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 
20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + 
"vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), 
%%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 14 + 
"vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, 
%%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_9])\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm11, 0(%[dst_9], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..6975229c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c @@ -0,0 +1,908 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n" + "vmovups 64(%[dst_9], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 0(%[bias]), %%zmm22\n" + "vmovups 64(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" +
"vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6), [dst_9] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, 
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps 
%%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, 
%%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, 
%%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n" + "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [dst_9] "r"(dst_9), [src_3] "r"(src_3), [src_6] "r"(src_6), + [src_9] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", 
+ "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..5b51eb3f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 10 + 
"vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..78859199 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 
1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..c2b38c19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), 
%%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + 
"vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000..a7c32c93 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c @@ -0,0 +1,278 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 15 + "vmovups 3840(%[weight]), 
%%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000..a6a2faeb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c @@ -0,0 +1,318 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" +
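// one depth step: broadcast a single fp32 of src into zmm26 and accumulate five 16-float weight columns (80 outputs) into zmm0-zmm4 +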
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" 
+ "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : 
"%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c new file mode 100644 index 00000000..c9a2b59e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c @@ -0,0 +1,358 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 1 + 
"vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vfmadd231ps 
%%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 14 + "vmovups 
5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..8db14458 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" +
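// one weight column is shared by both rows: zmm30/zmm29 broadcast the same depth element from src_0 and src_0 + src_stride +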
"vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] 
"r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..667f8698 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c @@ -0,0 +1,261 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 2 + 
"vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..b494d316 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c @@ -0,0 +1,324 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" +
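// row 1 reuses the same three weight columns, accumulating into zmm3-zmm5 +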
"vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 9 + "vmovups 
1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 
16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000..99cce4e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c @@ -0,0 +1,387 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26,
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" 
+ "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 
56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + 
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000..f396dce2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), 
%%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), 
%%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, 
%%zmm25, %%zmm9\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + 
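// 2x80 tile: five weight vectors per depth step; row 0 accumulates in zmm0-zmm4, row 1 in zmm5-zmm9 +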
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), 
[dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c new file mode 100644 index 00000000..7db94d09 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c @@ -0,0 +1,513 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), 
[bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, 
%%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, 
%%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, 
%%zmm11\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + 
"cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..d576a986 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm2\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), 
%%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..ffe63a35 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c @@ -0,0 +1,324 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 0(%[bias]), %%zmm2\n"
+    "vmovups 64(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 64(%[bias]), %%zmm5\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5");
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vbroadcastss 0(%[src_0]), %%zmm29\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
+    // block 1
+    "vmovups 128(%[weight]), %%zmm31\n"
+    "vmovups 192(%[weight]), %%zmm30\n"
+    "vbroadcastss 4(%[src_0]), %%zmm29\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
+    // block 2
+    "vmovups 256(%[weight]), %%zmm31\n"
+    "vmovups 320(%[weight]), %%zmm30\n"
+    "vbroadcastss 8(%[src_0]), %%zmm29\n"
+    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
+    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
+    // block 3
+    "vmovups 384(%[weight]), %%zmm31\n"
+    "vmovups 448(%[weight]), %%zmm30\n"
+    "vbroadcastss 12(%[src_0]), %%zmm29\n"
+    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
+    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm30, %%zmm28,
%%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + 
"vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..e94ec339 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c @@ -0,0 +1,410 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_3x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 128(%[bias]), %%zmm2\n"
+    "vmovups 0(%[bias]), %%zmm3\n"
+    "vmovups 64(%[bias]), %%zmm4\n"
+    "vmovups 128(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 64(%[bias]), %%zmm7\n"
+    "vmovups 128(%[bias]), %%zmm8\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8");
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vmovups 128(%[weight]), %%zmm29\n"
+    "vbroadcastss 0(%[src_0]), %%zmm28\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n"
+    "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n"
+    // block 1
+    "vmovups 192(%[weight]), %%zmm31\n"
+    "vmovups 256(%[weight]), %%zmm30\n"
+    "vmovups 320(%[weight]), %%zmm29\n"
+    "vbroadcastss 4(%[src_0]), %%zmm28\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n"
+    "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n"
+    // block 2
+    "vmovups 384(%[weight]), %%zmm31\n"
+    "vmovups 448(%[weight]), %%zmm30\n"
+    "vmovups 512(%[weight]), %%zmm29\n"
+    "vbroadcastss 8(%[src_0]), %%zmm28\n"
+    "vbroadcastss 8(%[src_0],
%[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + 
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, 
%%zmm8\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000..976212df --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c @@ -0,0 +1,496 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 192(%[dst_0]), %%zmm3\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 128(%[bias]), %%zmm2\n"
+    "vmovups 192(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 64(%[bias]), %%zmm5\n"
+    "vmovups 128(%[bias]), %%zmm6\n"
+    "vmovups 192(%[bias]), %%zmm7\n"
+    "vmovups 0(%[bias]), %%zmm8\n"
+    "vmovups 64(%[bias]), %%zmm9\n"
+    "vmovups 128(%[bias]), %%zmm10\n"
+    "vmovups 192(%[bias]), %%zmm11\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
+      "%zmm11");
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vmovups 128(%[weight]), %%zmm29\n"
+    "vmovups 192(%[weight]), %%zmm28\n"
+    "vbroadcastss 0(%[src_0]), %%zmm27\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
+    // block 1
+    "vmovups 256(%[weight]), %%zmm31\n"
+    "vmovups 320(%[weight]), %%zmm30\n"
+    "vmovups 384(%[weight]), %%zmm29\n"
+    "vmovups 448(%[weight]), %%zmm28\n"
+    "vbroadcastss 4(%[src_0]), %%zmm27\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n"
+    "vbroadcastss 4(%[src_0],
%[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps 
%%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + 
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, 
%%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + 
"vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000..42c46be6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c @@ -0,0 +1,583 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 192(%[dst_0]), %%zmm3\n"
+    "vmovups 256(%[dst_0]), %%zmm4\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
+    "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
+    "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 128(%[bias]), %%zmm2\n"
+    "vmovups 192(%[bias]), %%zmm3\n"
+    "vmovups 256(%[bias]), %%zmm4\n"
+    "vmovups 0(%[bias]), %%zmm5\n"
+    "vmovups 64(%[bias]), %%zmm6\n"
+    "vmovups 128(%[bias]), %%zmm7\n"
+    "vmovups 192(%[bias]), %%zmm8\n"
+    "vmovups 256(%[bias]), %%zmm9\n"
+    "vmovups 0(%[bias]), %%zmm10\n"
+    "vmovups 64(%[bias]), %%zmm11\n"
+    "vmovups 128(%[bias]), %%zmm12\n"
+    "vmovups 192(%[bias]), %%zmm13\n"
+    "vmovups 256(%[bias]), %%zmm14\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
+      "%zmm12", "%zmm13", "%zmm14");
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vmovups 128(%[weight]), %%zmm29\n"
+    "vmovups 192(%[weight]), %%zmm28\n"
+    "vmovups 256(%[weight]), %%zmm27\n"
+    "vbroadcastss 0(%[src_0]), %%zmm26\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25,
%%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), 
%%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, 
%%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + 
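+    // depth-step pattern: the five zmm weight columns (80 floats) loaded
+    // above are paired with one broadcast scalar per src row below, then
+    // FMA-accumulated into the 3x5 zmm tile.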
"vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" 
+ "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" 
+ "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c new file mode 100644 index 00000000..450dd072 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c @@ -0,0 +1,669 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), 
%%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), 
%%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], 
%[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 
36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + 
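+    // depth step 12 (byte offset 48): the 3x96 variant keeps six weight
+    // vectors live per step, giving a 3-row x 6-zmm accumulator tile
+    // (zmm0-zmm17).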
"vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), 
%%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps 
%%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..1c592474 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c @@ -0,0 +1,283 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 
1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..527bcd9f --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c @@ -0,0 +1,392 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + 
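+    // Scalar reference for what this generated 4x32 kernel computes; an
+    // illustrative sketch only, with strides counted in floats as in the
+    // C prologue above:
+    //   for (size_t d = 0; d < depth; ++d)
+    //     for (int r = 0; r < 4; ++r)        // row_block = 4 src rows
+    //       for (int c = 0; c < 32; ++c)     // col_block = 32 output columns
+    //         dst[r * dst_stride + c] += src[r * src_stride + d] * weight[d * 32 + c];
+    // followed by optional bias init / dst accumulation (inc_flag) and
+    // relu / relu6 (act_flag).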
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, 
%%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..0443d216 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c @@ -0,0 +1,501 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_4x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n"
+    "vmovups 0(%[dst_3]), %%zmm9\n"
+    "vmovups 64(%[dst_3]), %%zmm10\n"
+    "vmovups 128(%[dst_3]), %%zmm11\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 128(%[bias]), %%zmm2\n"
+    "vmovups 0(%[bias]), %%zmm3\n"
+    "vmovups 64(%[bias]), %%zmm4\n"
+    "vmovups 128(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 64(%[bias]), %%zmm7\n"
+    "vmovups 128(%[bias]), %%zmm8\n"
+    "vmovups 0(%[bias]), %%zmm9\n"
+    "vmovups 64(%[bias]), %%zmm10\n"
+    "vmovups 128(%[bias]), %%zmm11\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11");
+  const float *src_3 = src + 3 * src_stride;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vmovups 128(%[weight]), %%zmm29\n"
+    "vbroadcastss 0(%[src_0]), %%zmm28\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n"
+    "vbroadcastss 0(%[src_3]), %%zmm25\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n"
+    "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n"
+    // block 1
+    "vmovups 192(%[weight]), %%zmm31\n"
+    "vmovups 256(%[weight]), %%zmm30\n"
+    "vmovups 320(%[weight]), %%zmm29\n"
+    "vbroadcastss 4(%[src_0]), %%zmm28\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" +
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + 
"vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, 
%%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000..a88f4329 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c @@ -0,0 +1,611 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 192(%[dst_0]), %%zmm3\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n"
+    "vmovups 0(%[dst_3]), %%zmm12\n"
+    "vmovups 64(%[dst_3]), %%zmm13\n"
+    "vmovups 128(%[dst_3]), %%zmm14\n"
+    "vmovups 192(%[dst_3]), %%zmm15\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 128(%[bias]), %%zmm2\n"
+    "vmovups 192(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 64(%[bias]), %%zmm5\n"
+    "vmovups 128(%[bias]), %%zmm6\n"
+    "vmovups 192(%[bias]), %%zmm7\n"
+    "vmovups 0(%[bias]), %%zmm8\n"
+    "vmovups 64(%[bias]), %%zmm9\n"
+    "vmovups 128(%[bias]), %%zmm10\n"
+    "vmovups 192(%[bias]), %%zmm11\n"
+    "vmovups 0(%[bias]), %%zmm12\n"
+    "vmovups 64(%[bias]), %%zmm13\n"
+    "vmovups 128(%[bias]), %%zmm14\n"
+    "vmovups 192(%[bias]), %%zmm15\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+    "vxorps %%zmm15, %%zmm15, %%zmm15\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
+      "%zmm12", "%zmm13", "%zmm14", "%zmm15");
+  const float *src_3 = src + 3 * src_stride;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vmovups 128(%[weight]), %%zmm29\n"
+    "vmovups 192(%[weight]), %%zmm28\n"
+    "vbroadcastss 0(%[src_0]), %%zmm27\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
+    "vbroadcastss 0(%[src_3]), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" +
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, 
%%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + 
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" 
+ "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000..10d52e6f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c @@ -0,0 +1,720 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride]
"r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, 
%%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 
20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 8 + "vmovups 2560(%[weight]), 
%%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, 
%%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + 
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], 
%[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + 
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c new file mode 100644 index 00000000..48f890f5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c @@ -0,0 +1,830 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_3]), %%zmm18\n" + "vmovups 64(%[dst_3]), %%zmm19\n" + "vmovups 128(%[dst_3]), %%zmm20\n" + "vmovups 192(%[dst_3]), %%zmm21\n" + "vmovups 256(%[dst_3]), %%zmm22\n" + "vmovups 320(%[dst_3]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 192(%[bias]), %%zmm21\n" + "vmovups 256(%[bias]), %%zmm22\n" + "vmovups 320(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0]
"r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), 
%%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, 
%%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + 
"vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, 
%%zmm11\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps 
%%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 
0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + 
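+      // writeback of the 4x96 tile: six 64-byte zmm stores per output row; rows 0-2 go to dst_0 (+ 0/1/2 * dst_stride bytes), row 3 to dst_3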
"vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_3])\n" + "vmovups %%zmm19, 64(%[dst_3])\n" + "vmovups %%zmm20, 128(%[dst_3])\n" + "vmovups %%zmm21, 192(%[dst_3])\n" + "vmovups %%zmm22, 256(%[dst_3])\n" + "vmovups %%zmm23, 320(%[dst_3])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..097ffc2c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c @@ -0,0 +1,323 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + 
"vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, 
%%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..4c0c4315 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c @@ -0,0 +1,455 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, 
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], 
%[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + 
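+      // each numbered block consumes one depth element: two 16-float weight columns (zmm31/zmm30) are scaled by a broadcast activation from each of the five src rows and accumulated into the 5x32 tile in zmm0-zmm9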
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + 
"vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", 
"%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..e5d0f283 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c @@ -0,0 +1,588 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, 
%%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], 
%[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps 
%%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + 
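+      // byte offset 52 selects depth element 13 of the 16-wide unroll (4 bytes per fp32); the other rows are reached through src_stride_t, the element stride pre-scaled to bytes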
"vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000..fc0a6bc4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c @@ -0,0 +1,720 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps
%%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + 
"vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, 
%%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, 
%%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 
2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps 
%%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps 
%%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, 
%%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000..9ec37913 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c @@ -0,0 +1,853 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n" + "vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "vmovups 256(%[bias]), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21,
%%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "vxorps %%zmm24, %%zmm24, %%zmm24\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), 
%%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + 
"vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, 
%%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, 
%%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + 
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), 
%%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, 
%%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "vmaxps %%zmm24, %%zmm31, %%zmm24\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "vminps %%zmm24, %%zmm30, %%zmm24\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..05f6f849 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c @@ -0,0 +1,363 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 2 + "vmovups 
128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 
52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], 
%[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..f59cc5f0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c @@ -0,0 +1,518 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, 
%%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + 
"vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, 
%%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + 
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps 
%%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, 
%%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..f083e994 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c @@ -0,0 +1,674 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 
0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + 
"vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" 
+ // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, 
%%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, 
%%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + 
"vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000..71d8d74d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c @@ -0,0 +1,830 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12,
%%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, 
%%zmm24, %%zmm15\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 4 + 
"vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + 
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + 
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" 
+ "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + 
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, 
%[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [src_3] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..a53f0dc4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c @@ -0,0 +1,408 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps
%%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps 
%%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + 
"vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 
0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..9d4a6694 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c @@ -0,0 +1,587 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
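The 7x16 kernel above is easier to audit against a scalar reference. The sketch below restates what the asm computes, under the conventions visible in the code: strides are in floats (the asm shifts them left by 2 to get bytes), bit 0 of inc_flag means "reload partial sums from dst" (a continued depth split), bit 1 means "last depth chunk, apply the activation", any set bit of act_flag & 0x3 enables relu, and bit 0 of act_flag additionally clamps at 6.0f (the 0x40C00000 constant) for relu6. The reference function name is illustrative and not part of this patch; the kernel's unused row_block/col_block parameters are dropped.

#include <stddef.h>

// Scalar model of nnacl_gemm_avx512_7x16_kernel_nhwc_fp32 (row_block = 7,
// col_block = 16). The real kernel keeps acc[7][16] entirely in zmm0-zmm6.
static void gemm_7x16_reference(float *dst, const float *src, const float *weight,
                                const float *bias, size_t act_flag, size_t depth,
                                size_t src_stride, size_t dst_stride, size_t inc_flag) {
  float acc[7][16];
  for (int r = 0; r < 7; ++r) {
    for (int c = 0; c < 16; ++c) {
      if (inc_flag & 0x1) {
        acc[r][c] = dst[r * dst_stride + c];              // continue a depth split
      } else {
        acc[r][c] = (bias != NULL) ? bias[c] : 0.0f;      // bias broadcast per column
      }
    }
  }
  for (size_t d = 0; d < depth; ++d) {                    // one fma block per depth step
    for (int r = 0; r < 7; ++r) {
      for (int c = 0; c < 16; ++c) {
        acc[r][c] += src[r * src_stride + d] * weight[d * 16 + c];
      }
    }
  }
  for (int r = 0; r < 7; ++r) {
    for (int c = 0; c < 16; ++c) {
      float v = acc[r][c];
      if ((inc_flag & 0x2) && (act_flag & 0x3)) {         // last chunk: activation
        v = v > 0.0f ? v : 0.0f;                          // relu
        if (act_flag & 0x1) v = v < 6.0f ? v : 6.0f;      // relu6
      }
      dst[r * dst_stride + c] = v;
    }
  }
}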
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 4 + "vmovups 512(%[weight]), 
%%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 
28(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps 
%%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + 
"vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, 
%%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..cc4b94c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c @@ -0,0 +1,765 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", 
"%zmm18", "%zmm19", "%zmm20"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 4(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 8(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + 
"vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 12(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 16(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps 
%%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 20(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 24(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 28(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 32(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 36(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps 
%%zmm29, %%zmm22, %%zmm20\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 40(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 44(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 48(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + 
"vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 52(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 56(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 15 + "vmovups 2880(%[weight]), 
%%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 60(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + 
"vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6])\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..fb6fb037 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c @@ -0,0 +1,448 @@ +/** + * Copyright 2021 Huawei 
Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_8x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  const float *dst_6 = dst + 6 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n"
+    "vmovups 0(%[dst_3]), %%zmm3\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n"
+    "vmovups 0(%[dst_6]), %%zmm6\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 0(%[bias]), %%zmm1\n"
+    "vmovups 0(%[bias]), %%zmm2\n"
+    "vmovups 0(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 0(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 0(%[bias]), %%zmm7\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3),
+      [dst_6] "r"(dst_6)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7");
+  const float *src_3 = src + 3 * src_stride;
+  const float *src_6 = src + 6 * src_stride;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vbroadcastss 0(%[src_0]), %%zmm30\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n"
+    "vbroadcastss 0(%[src_3]), %%zmm27\n"
+    "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n"
+    "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n"
+    "vbroadcastss 0(%[src_6]), %%zmm24\n"
+    "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n"
+    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
+    "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n"
+    // block 1
+    "vmovups 64(%[weight]), %%zmm31\n"
+    "vbroadcastss
4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 
20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + 
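+    // note: this 8x16 variant loads a single 16-float weight column per depth step, so the whole
+    // output tile fits in zmm0-zmm7 (one accumulator register per row)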
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 
56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 
0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..93d37efc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c @@ -0,0 +1,650 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  const float *dst_6 = dst + 6 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
+    "vmovups 0(%[dst_3]), %%zmm6\n"
+    "vmovups 64(%[dst_3]), %%zmm7\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
+    "vmovups 0(%[dst_6]), %%zmm12\n"
+    "vmovups 64(%[dst_6]), %%zmm13\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 0(%[bias]), %%zmm2\n"
+    "vmovups 64(%[bias]), %%zmm3\n"
+    "vmovups 0(%[bias]), %%zmm4\n"
+    "vmovups 64(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 64(%[bias]), %%zmm7\n"
+    "vmovups 0(%[bias]), %%zmm8\n"
+    "vmovups 64(%[bias]), %%zmm9\n"
+    "vmovups 0(%[bias]), %%zmm10\n"
+    "vmovups 64(%[bias]), %%zmm11\n"
+    "vmovups 0(%[bias]), %%zmm12\n"
+    "vmovups 64(%[bias]), %%zmm13\n"
+    "vmovups 0(%[bias]), %%zmm14\n"
+    "vmovups 64(%[bias]), %%zmm15\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+    "vxorps %%zmm15, %%zmm15, %%zmm15\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3),
+      [dst_6] "r"(dst_6)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
+      "%zmm12", "%zmm13", "%zmm14", "%zmm15");
+  const float *src_3 = src + 3 * src_stride;
+  const float *src_6 = src + 6 * src_stride;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vbroadcastss 0(%[src_0]), %%zmm29\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
+    "vbroadcastss 0(%[src_3]), %%zmm26\n"
+    "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n"
+    "vbroadcastss
0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 
2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), 
%%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), 
%%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], 
%[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], 
%[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + 
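+    // note: 0x40C00000 above is the IEEE-754 single-precision bit pattern of 6.0f, broadcast into zmm30
+    // as the relu6 upper bound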
"vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c new file mode 100644 index 00000000..b04839be --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c @@ -0,0 +1,852 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_8x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  const float *dst_3 = dst + 3 * dst_stride;
+  const float *dst_6 = dst + 6 * dst_stride;
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n"
+    "vmovups 0(%[dst_3]), %%zmm9\n"
+    "vmovups 64(%[dst_3]), %%zmm10\n"
+    "vmovups 128(%[dst_3]), %%zmm11\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n"
+    "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n"
+    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n"
+    "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n"
+    "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n"
+    "vmovups 0(%[dst_6]), %%zmm18\n"
+    "vmovups 64(%[dst_6]), %%zmm19\n"
+    "vmovups 128(%[dst_6]), %%zmm20\n"
+    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm21\n"
+    "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm22\n"
+    "vmovups 128(%[dst_6], %[dst_stride], 1), %%zmm23\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "vmovups 128(%[bias]), %%zmm2\n"
+    "vmovups 0(%[bias]), %%zmm3\n"
+    "vmovups 64(%[bias]), %%zmm4\n"
+    "vmovups 128(%[bias]), %%zmm5\n"
+    "vmovups 0(%[bias]), %%zmm6\n"
+    "vmovups 64(%[bias]), %%zmm7\n"
+    "vmovups 128(%[bias]), %%zmm8\n"
+    "vmovups 0(%[bias]), %%zmm9\n"
+    "vmovups 64(%[bias]), %%zmm10\n"
+    "vmovups 128(%[bias]), %%zmm11\n"
+    "vmovups 0(%[bias]), %%zmm12\n"
+    "vmovups 64(%[bias]), %%zmm13\n"
+    "vmovups 128(%[bias]), %%zmm14\n"
+    "vmovups 0(%[bias]), %%zmm15\n"
+    "vmovups 64(%[bias]), %%zmm16\n"
+    "vmovups 128(%[bias]), %%zmm17\n"
+    "vmovups 0(%[bias]), %%zmm18\n"
+    "vmovups 64(%[bias]), %%zmm19\n"
+    "vmovups 128(%[bias]), %%zmm20\n"
+    "vmovups 0(%[bias]), %%zmm21\n"
+    "vmovups 64(%[bias]), %%zmm22\n"
+    "vmovups 128(%[bias]), %%zmm23\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+    "vxorps %%zmm15, %%zmm15, %%zmm15\n"
+    "vxorps %%zmm16, %%zmm16, %%zmm16\n"
+    "vxorps %%zmm17, %%zmm17, %%zmm17\n"
+    "vxorps %%zmm18, %%zmm18, %%zmm18\n"
+    "vxorps %%zmm19, %%zmm19, %%zmm19\n"
+    "vxorps %%zmm20, %%zmm20, %%zmm20\n"
+    "vxorps %%zmm21, %%zmm21, %%zmm21\n"
+    "vxorps %%zmm22, %%zmm22, %%zmm22\n"
"vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_6]), %%zmm27\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_6]), %%zmm27\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_6]), %%zmm27\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), 
%%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_6]), %%zmm27\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_6]), %%zmm27\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_6]), %%zmm27\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_6]), %%zmm27\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, 
%%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_6]), %%zmm27\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_6]), %%zmm27\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_6]), %%zmm27\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm26\n" + 
"vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_6]), %%zmm27\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_6]), %%zmm27\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm23\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_6]), %%zmm27\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_6]), %%zmm27\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + 
"vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_6]), %%zmm27\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, 
%%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6])\n" + "vmovups %%zmm21, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm22, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm23, 128(%[dst_6], %[dst_stride], 1)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c new file mode 100644 index 00000000..da7707ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c @@ -0,0 +1,488 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]),
%%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + 
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 8 + "vmovups 512(%[weight]), 
%%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps 
%%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000..95e0308e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c @@ -0,0 +1,713 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7,
%%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [dst_0] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_3] "r"(dst_3), + [dst_6] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps 
%%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + 
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + 
"vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps 
%%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, 
%%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + 
"vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps 
%%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + : + : [src_0] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [depth] "r"(depth), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst_0] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_3] "r"(dst_3), [dst_6] "r"(dst_6), [src_3] "r"(src_3), [src_6] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..16035268 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); +
dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 
13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + 
dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..0470cc51 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,303 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]),
%%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, 
%%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + "vmovups %%ymm9, 288(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", 
"%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..dcc963de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,321 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + __m256 dst10; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = 
_mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + __m256 src100 = _mm256_set1_ps(*(src + 80)); + dst10 = _mm256_fmadd_ps(dst10, src100, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + __m256 src101 = _mm256_set1_ps(*(src + 81)); + dst10 = _mm256_fmadd_ps(dst10, src101, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + __m256 src102 = _mm256_set1_ps(*(src + 82)); + dst10 = _mm256_fmadd_ps(dst10, src102, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + 
__m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + __m256 src103 = _mm256_set1_ps(*(src + 83)); + dst10 = _mm256_fmadd_ps(dst10, src103, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + __m256 src104 = _mm256_set1_ps(*(src + 84)); + dst10 = _mm256_fmadd_ps(dst10, src104, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + __m256 src105 = _mm256_set1_ps(*(src + 85)); + dst10 = _mm256_fmadd_ps(dst10, src105, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 
78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + __m256 src106 = _mm256_set1_ps(*(src + 86)); + dst10 = _mm256_fmadd_ps(dst10, src106, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + __m256 src107 = _mm256_set1_ps(*(src + 87)); + dst10 = _mm256_fmadd_ps(dst10, src107, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); + _mm256_store_ps(dst + 0 * src_stride + 80, dst10); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..83a66a8a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,325 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "vmovups 320(%[dst]), %%ymm10\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "vmovaps 0(%[bias]), %%ymm10\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + 
"vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vbroadcastss 320(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vbroadcastss 321(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vbroadcastss 322(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vbroadcastss 323(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + 
"vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vbroadcastss 324(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vbroadcastss 325(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vbroadcastss 326(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vbroadcastss 327(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + 
"add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + "vmovups %%ymm9, 288(%[dst])\n" + "vmovups %%ymm10, 320(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..8e39ab64 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,345 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + __m256 dst10; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80); + dst11 = _mm256_load_ps(dst + 0 * dst_stride + 88); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 0); + dst11 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + __m256 src100 = _mm256_set1_ps(*(src + 80)); + dst10 = _mm256_fmadd_ps(dst10, src100, weight00); + __m256 src110 = _mm256_set1_ps(*(src + 88)); + dst11 = _mm256_fmadd_ps(dst11, src110, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = 
_mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + __m256 src101 = _mm256_set1_ps(*(src + 81)); + dst10 = _mm256_fmadd_ps(dst10, src101, weight01); + __m256 src111 = _mm256_set1_ps(*(src + 89)); + dst11 = _mm256_fmadd_ps(dst11, src111, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + __m256 src102 = _mm256_set1_ps(*(src + 82)); + dst10 = _mm256_fmadd_ps(dst10, src102, weight02); + __m256 src112 = _mm256_set1_ps(*(src + 90)); + dst11 = _mm256_fmadd_ps(dst11, src112, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + __m256 src103 = _mm256_set1_ps(*(src + 83)); + dst10 = _mm256_fmadd_ps(dst10, src103, weight03); + __m256 src113 = _mm256_set1_ps(*(src + 91)); + dst11 = _mm256_fmadd_ps(dst11, src113, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, 
weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + __m256 src104 = _mm256_set1_ps(*(src + 84)); + dst10 = _mm256_fmadd_ps(dst10, src104, weight04); + __m256 src114 = _mm256_set1_ps(*(src + 92)); + dst11 = _mm256_fmadd_ps(dst11, src114, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + __m256 src105 = _mm256_set1_ps(*(src + 85)); + dst10 = _mm256_fmadd_ps(dst10, src105, weight05); + __m256 src115 = _mm256_set1_ps(*(src + 93)); + dst11 = _mm256_fmadd_ps(dst11, src115, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + __m256 src106 = _mm256_set1_ps(*(src + 86)); + dst10 = _mm256_fmadd_ps(dst10, src106, weight06); + __m256 src116 = _mm256_set1_ps(*(src + 94)); + dst11 = _mm256_fmadd_ps(dst11, src116, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = 
_mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + __m256 src107 = _mm256_set1_ps(*(src + 87)); + dst10 = _mm256_fmadd_ps(dst10, src107, weight07); + __m256 src117 = _mm256_set1_ps(*(src + 95)); + dst11 = _mm256_fmadd_ps(dst11, src117, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); + _mm256_store_ps(dst + 0 * src_stride + 80, dst10); + _mm256_store_ps(dst + 0 * src_stride + 88, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..e478150a --- 
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,347 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "vmovups 320(%[dst]), %%ymm10\n" + "vmovups 352(%[dst]), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "vmovaps 0(%[bias]), %%ymm10\n" + "vmovaps 0(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, 
%%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vbroadcastss 320(%[src]), %%ymm13\n" + "vbroadcastss 352(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vbroadcastss 321(%[src]), %%ymm13\n" + "vbroadcastss 353(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vbroadcastss 322(%[src]), %%ymm13\n" + "vbroadcastss 354(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vbroadcastss 323(%[src]), %%ymm13\n" + "vbroadcastss 355(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 
4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vbroadcastss 324(%[src]), %%ymm13\n" + "vbroadcastss 356(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vbroadcastss 325(%[src]), %%ymm13\n" + "vbroadcastss 357(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vbroadcastss 326(%[src]), %%ymm13\n" + "vbroadcastss 358(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, 
%%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
+    "vbroadcastss 199(%[src]), %%ymm14\n"
+    "vbroadcastss 231(%[src]), %%ymm13\n"
+    "vbroadcastss 263(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
+    "vbroadcastss 295(%[src]), %%ymm14\n"
+    "vbroadcastss 327(%[src]), %%ymm13\n"
+    "vbroadcastss 359(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
+    "dec %[deep]\n"
+    "add $256, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "jg 0b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+    "vmaxps %%ymm2, %%ymm15, %%ymm2\n"
+    "vmaxps %%ymm3, %%ymm15, %%ymm3\n"
+    "vmaxps %%ymm4, %%ymm15, %%ymm4\n"
+    "vmaxps %%ymm5, %%ymm15, %%ymm5\n"
+    "vmaxps %%ymm6, %%ymm15, %%ymm6\n"
+    "vmaxps %%ymm7, %%ymm15, %%ymm7\n"
+    "vmaxps %%ymm8, %%ymm15, %%ymm8\n"
+    "vmaxps %%ymm9, %%ymm15, %%ymm9\n"
+    "vmaxps %%ymm10, %%ymm15, %%ymm10\n"
+    "vmaxps %%ymm11, %%ymm15, %%ymm11\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "vminps %%ymm1, %%ymm14, %%ymm1\n"
+    "vminps %%ymm2, %%ymm14, %%ymm2\n"
+    "vminps %%ymm3, %%ymm14, %%ymm3\n"
+    "vminps %%ymm4, %%ymm14, %%ymm4\n"
+    "vminps %%ymm5, %%ymm14, %%ymm5\n"
+    "vminps %%ymm6, %%ymm14, %%ymm6\n"
+    "vminps %%ymm7, %%ymm14, %%ymm7\n"
+    "vminps %%ymm8, %%ymm14, %%ymm8\n"
+    "vminps %%ymm9, %%ymm14, %%ymm9\n"
+    "vminps %%ymm10, %%ymm14, %%ymm10\n"
+    "vminps %%ymm11, %%ymm14, %%ymm11\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    "vmovups %%ymm1, 32(%[dst])\n"
+    "vmovups %%ymm2, 64(%[dst])\n"
+    "vmovups %%ymm3, 96(%[dst])\n"
+    "vmovups %%ymm4, 128(%[dst])\n"
+    "vmovups %%ymm5, 160(%[dst])\n"
+    "vmovups %%ymm6, 192(%[dst])\n"
+    "vmovups %%ymm7, 224(%[dst])\n"
+    "vmovups %%ymm8, 256(%[dst])\n"
+    "vmovups %%ymm9, 288(%[dst])\n"
+    "vmovups %%ymm10, 320(%[dst])\n"
+    "vmovups %%ymm11, 352(%[dst])\n"
+    :
+    : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
+}
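Every kernel in this series ends with the same activation epilogue, decoding act_flag as a two-bit mask. As a readability aid, here is a minimal scalar sketch of that epilogue; the function name is hypothetical, and the body mirrors the `act_flag & 0x02` (ReLU6) and `act_flag & 0x01` (ReLU) branches of the intrinsic variants above:

#include <stddef.h>

/* Scalar reference for the shared activation epilogue: bit 1 of act_flag
 * requests ReLU6 (clamp to [0, 6]), bit 0 requests plain ReLU. */
static void apply_act_scalar(float *dst, size_t len, size_t act_flag) {
  for (size_t i = 0; i < len; ++i) {
    float v = dst[i];
    if (act_flag & 0x02) {      /* relu6: min against 6, then max against 0 */
      v = v > 6.0f ? 6.0f : v;
      v = v < 0.0f ? 0.0f : v;
    }
    if (act_flag & 0x01) {      /* relu: clamp below at 0 */
      v = v < 0.0f ? 0.0f : v;
    }
    dst[i] = v;
  }
}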
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c
new file mode 100644
index 00000000..87a830fb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code
+void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  __m256 dst0;
+  __m256 dst1;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst1 = _mm256_load_ps(bias + 8);
+  }
+  for (int i = 0; i < (deep >> 3); ++i) {
+    // block0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 weight10 = _mm256_load_ps(weight + 8);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
+    dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
+    // block1
+    __m256 weight01 = _mm256_load_ps(weight + 16);
+    __m256 weight11 = _mm256_load_ps(weight + 24);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
+    dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
+    // block2
+    __m256 weight02 = _mm256_load_ps(weight + 32);
+    __m256 weight12 = _mm256_load_ps(weight + 40);
+    __m256 src02 = _mm256_set1_ps(*(src + 2));
+    dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
+    dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
+    // block3
+    __m256 weight03 = _mm256_load_ps(weight + 48);
+    __m256 weight13 = _mm256_load_ps(weight + 56);
+    __m256 src03 = _mm256_set1_ps(*(src + 3));
+    dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
+    dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
+    // block4
+    __m256 weight04 = _mm256_load_ps(weight + 64);
+    __m256 weight14 = _mm256_load_ps(weight + 72);
+    __m256 src04 = _mm256_set1_ps(*(src + 4));
+    dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
+    dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
+    // block5
+    __m256 weight05 = _mm256_load_ps(weight + 80);
+    __m256 weight15 = _mm256_load_ps(weight + 88);
+    __m256 src05 = _mm256_set1_ps(*(src + 5));
+    dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
+    dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
+    // block6
+    __m256 weight06 = _mm256_load_ps(weight + 96);
+    __m256 weight16 = _mm256_load_ps(weight + 104);
+    __m256 src06 = _mm256_set1_ps(*(src + 6));
+    dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
+    dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
+    // block7
+    __m256 weight07 = _mm256_load_ps(weight + 112);
+    __m256 weight17 = _mm256_load_ps(weight + 120);
+    __m256 src07 = _mm256_set1_ps(*(src + 7));
+    dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
+    dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
+    src = src + src_stride;
+    weight += 512;
+  }
+  if (act_flag & 0x02) {
+    // relu6
+    __m256 relu6 = _mm256_set1_ps(6.0f);
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_min_ps(dst0, relu6);
+    dst1 = _mm256_min_ps(dst1, relu6);
+    // relu
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+  }
+  if (act_flag & 0x01) {
+    // relu
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+  }
+  _mm256_store_ps(dst + 0 * src_stride + 0, dst0);
+  _mm256_store_ps(dst + 1 * src_stride + 0, dst1);
+}
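One operational note on these kernels: weight and bias panels are read with _mm256_load_ps (vmovaps in the asm variants), which faults unless the address is 32-byte aligned, while dst is written with unaligned vmovups in the asm variants. A hedged allocation sketch for a packed panel follows; alloc_packed and pack_size are illustrative names, not nnacl APIs:

#include <stdlib.h>
#include <string.h>

/* Illustrative only: buffers fed to the aligned loads above would need
 * 32-byte alignment. pack_size stands in for whatever element count the
 * packing step computes. */
static float *alloc_packed(size_t pack_size) {
  /* aligned_alloc (C11) requires the size to be a multiple of the alignment. */
  size_t bytes = ((pack_size * sizeof(float) + 31) / 32) * 32;
  float *buf = (float *)aligned_alloc(32, bytes);
  if (buf != NULL) {
    memset(buf, 0, bytes);  /* zero any padding lanes */
  }
  return buf;
}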
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
new file mode 100644
index 00000000..5b592d4b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
@@ -0,0 +1,127 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "vmovaps 32(%[bias]), %%ymm1\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "2:\n"
+    :
+    : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%ymm0", "%ymm1");
+  asm volatile(
+    "0:\n"
+    // block 0
+    "vmovaps 0(%[weight]), %%ymm15\n"
+    "vmovaps 32(%[weight]), %%ymm14\n"
+    "vbroadcastss 0(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    // block 1
+    "vmovaps 64(%[weight]), %%ymm15\n"
+    "vmovaps 96(%[weight]), %%ymm14\n"
+    "vbroadcastss 1(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    // block 2
+    "vmovaps 128(%[weight]), %%ymm15\n"
+    "vmovaps 160(%[weight]), %%ymm14\n"
+    "vbroadcastss 2(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    // block 3
+    "vmovaps 192(%[weight]), %%ymm15\n"
+    "vmovaps 224(%[weight]), %%ymm14\n"
+    "vbroadcastss 3(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    // block 4
+    "vmovaps 256(%[weight]), %%ymm15\n"
+    "vmovaps 288(%[weight]), %%ymm14\n"
+    "vbroadcastss 4(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    // block 5
+    "vmovaps 320(%[weight]), %%ymm15\n"
+    "vmovaps 352(%[weight]), %%ymm14\n"
+    "vbroadcastss 5(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    // block 6
+    "vmovaps 384(%[weight]), %%ymm15\n"
+    "vmovaps 416(%[weight]), %%ymm14\n"
+    "vbroadcastss 6(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    // block 7
+    "vmovaps 448(%[weight]), %%ymm15\n"
+    "vmovaps 480(%[weight]), %%ymm14\n"
+    "vbroadcastss 7(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
+    "dec %[deep]\n"
+    "add $512, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "jg 0b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "vminps %%ymm1, %%ymm14, %%ymm1\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
+    :
+    : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
+}
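In the asm epilogues the ReLU6 bound is built by moving the immediate 0x40C00000 into xmm14 and broadcasting lane 0 across the vector with vpermps, using the all-zero ymm15 as the index operand (every lane selects element 0). 0x40C00000 is the IEEE-754 single-precision bit pattern of 6.0f, which this stand-alone check confirms:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Verify that the ReLU6 immediate used by the asm kernels encodes 6.0f:
 * sign 0, exponent 127 + 2 = 0x81, mantissa 0x400000 (1.5 * 2^2 = 6.0). */
int main(void) {
  uint32_t bits = 0x40C00000u;
  float f;
  memcpy(&f, &bits, sizeof(f));              /* type-pun without UB */
  printf("0x%08X -> %f\n", (unsigned)bits, f); /* prints 6.000000 */
  return 0;
}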
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c
new file mode 100644
index 00000000..548f143c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c
@@ -0,0 +1,129 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code
+void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  __m256 dst0;
+  __m256 dst1;
+  __m256 dst2;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
+    dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+    dst2 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst1 = _mm256_load_ps(bias + 8);
+    dst2 = _mm256_load_ps(bias + 16);
+  }
+  for (int i = 0; i < (deep >> 3); ++i) {
+    // block0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 weight10 = _mm256_load_ps(weight + 8);
+    __m256 weight20 = _mm256_load_ps(weight + 16);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
+    dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
+    dst2 = _mm256_fmadd_ps(dst2, src00, weight20);
+    // block1
+    __m256 weight01 = _mm256_load_ps(weight + 24);
+    __m256 weight11 = _mm256_load_ps(weight + 32);
+    __m256 weight21 = _mm256_load_ps(weight + 40);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
+    dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
+    dst2 = _mm256_fmadd_ps(dst2, src01, weight21);
+    // block2
+    __m256 weight02 = _mm256_load_ps(weight + 48);
+    __m256 weight12 = _mm256_load_ps(weight + 56);
+    __m256 weight22 = _mm256_load_ps(weight + 64);
+    __m256 src02 = _mm256_set1_ps(*(src + 2));
+    dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
+    dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
+    dst2 = _mm256_fmadd_ps(dst2, src02, weight22);
+    // block3
+    __m256 weight03 = _mm256_load_ps(weight + 72);
+    __m256 weight13 = _mm256_load_ps(weight + 80);
+    __m256 weight23 = _mm256_load_ps(weight + 88);
+    __m256 src03 = _mm256_set1_ps(*(src + 3));
+    dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
+    dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
+    dst2 = _mm256_fmadd_ps(dst2, src03, weight23);
+    // block4
+    __m256 weight04 = _mm256_load_ps(weight + 96);
+    __m256 weight14 = _mm256_load_ps(weight + 104);
+    __m256 weight24 = _mm256_load_ps(weight + 112);
+    __m256 src04 = _mm256_set1_ps(*(src + 4));
+    dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
+    dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
+    dst2 = _mm256_fmadd_ps(dst2, src04, weight24);
+    // block5
+    __m256 weight05 = _mm256_load_ps(weight + 120);
+    __m256 weight15 = _mm256_load_ps(weight + 128);
+    __m256 weight25 = _mm256_load_ps(weight + 136);
+    __m256 src05 = _mm256_set1_ps(*(src + 5));
+    dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
+    dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
+    dst2 = _mm256_fmadd_ps(dst2, src05, weight25);
+    // block6
+    __m256 weight06 = _mm256_load_ps(weight + 144);
+    __m256 weight16 = _mm256_load_ps(weight + 152);
+    __m256 weight26 = _mm256_load_ps(weight + 160);
+    __m256 src06 = _mm256_set1_ps(*(src + 6));
+    dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
+    dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
+    dst2 = _mm256_fmadd_ps(dst2, src06, weight26);
+    // block7
+    __m256 weight07 = _mm256_load_ps(weight + 168);
+    __m256 weight17 = _mm256_load_ps(weight + 176);
+    __m256 weight27 = _mm256_load_ps(weight + 184);
+    __m256 src07 = _mm256_set1_ps(*(src + 7));
+    dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
+    dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
+    dst2 = _mm256_fmadd_ps(dst2, src07, weight27);
+    src = src + src_stride;
+    weight += 768;
+  }
+  if (act_flag & 0x02) {
+    // relu6
+    __m256 relu6 = _mm256_set1_ps(6.0f);
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_min_ps(dst0, relu6);
+    dst1 = _mm256_min_ps(dst1, relu6);
+    dst2 = _mm256_min_ps(dst2, relu6);
+    // relu
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+    dst2 = _mm256_max_ps(dst2, relu);
+  }
+  if (act_flag & 0x01) {
+    // relu
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+    dst2 = _mm256_max_ps(dst2, relu);
+  }
+  _mm256_store_ps(dst + 0 * src_stride + 0, dst0);
+  _mm256_store_ps(dst + 1 * src_stride + 0, dst1);
+  _mm256_store_ps(dst + 2 * src_stride + 0, dst2);
+}
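The inc_flag parameter lets a caller split the deep dimension across several kernel invocations: bit 0 reloads the partial sums already in dst instead of initializing from bias or zero, and the asm variants additionally gate the activation epilogue on bit 1 so it fires only once, on the final slice. A hypothetical driver sketch follows; the stride and weight-step arithmetic is inferred from the 1x24 intrinsic loop above (src advances by src_stride floats and weight by 768 floats per 8 depth steps) and is not nnacl's real scheduling code:

#include <stddef.h>

typedef void (*GemmFmaKernel)(float *dst, const float *src, const float *weight, const float *bias,
                              size_t act_flag, size_t row_block, size_t col_block, size_t deep,
                              size_t src_stride, size_t dst_stride, size_t inc_flag);

/* Illustrative deep-slicing driver; names and indexing are assumptions. */
static void run_deep_sliced(GemmFmaKernel kernel, float *dst, const float *src, const float *weight,
                            const float *bias, size_t deep, size_t slice, size_t src_stride,
                            size_t dst_stride, size_t act_flag) {
  for (size_t d = 0; d < deep; d += slice) {
    size_t last = (d + slice >= deep) ? 1 : 0;
    /* bit 0: accumulate onto dst; bit 1: final slice, run the activation. */
    size_t inc_flag = (d > 0 ? 0x1 : 0x0) | (last ? 0x2 : 0x0);
    kernel(dst, src + (d >> 3) * src_stride, weight + (d >> 3) * 768, bias,
           last ? act_flag : 0, 1, 24, slice, src_stride, dst_stride, inc_flag);
  }
}

Requesting the activation only on the last slice keeps the intrinsic variants, which do not check inc_flag bit 1, from clamping intermediate partial sums.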
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
new file mode 100644
index 00000000..ca1be9fc
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
@@ -0,0 +1,149 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
+    "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "vmovaps 32(%[bias]), %%ymm1\n"
+    "vmovaps 64(%[bias]), %%ymm2\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
+    "2:\n"
+    :
+    : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%ymm0", "%ymm1", "%ymm2");
dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..ca1be9fc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add $768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq
%[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..c23a6ff6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 3 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // block0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst2 = _mm256_fmadd_ps(dst2, src00, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst3 = _mm256_fmadd_ps(dst3, src00, weight30); + // block1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst2 = _mm256_fmadd_ps(dst2, src01, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst3 = _mm256_fmadd_ps(dst3, src01, weight31); + // block2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst2 = _mm256_fmadd_ps(dst2, src02, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst3 = _mm256_fmadd_ps(dst3, src02, weight32); + // block3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst2 = _mm256_fmadd_ps(dst2, src03, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst3 = _mm256_fmadd_ps(dst3, src03, weight33); + // block4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst2 = _mm256_fmadd_ps(dst2, src04, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst3 = _mm256_fmadd_ps(dst3, src04, weight34); + // block5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst2 = _mm256_fmadd_ps(dst2, 
src05, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst3 = _mm256_fmadd_ps(dst3, src05, weight35); + // block6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst2 = _mm256_fmadd_ps(dst2, src06, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst3 = _mm256_fmadd_ps(dst3, src06, weight36); + // block7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst3 = _mm256_fmadd_ps(dst3, src07, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); + _mm256_store_ps(dst + 3 * src_stride + 0, dst3); +}
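These kernels execute AVX2/FMA3 instructions unconditionally, so whatever code dispatches to them has to establish CPU support first. A minimal sketch of such a guard, using the GCC/Clang builtins (this helper name is illustrative, not part of the patch):

/* Hypothetical runtime guard for the FMA kernels above. */
int cpu_supports_fma_kernels(void) {
#if defined(__GNUC__) || defined(__clang__)
  __builtin_cpu_init();  // required before querying CPU features
  return __builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma");
#else
  return 0;  // conservatively refuse when feature detection is unavailable
#endif
}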
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..e327580a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,173 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "vmovups 0(%[dst_4]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "vmovaps 96(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_4] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vmovaps 0(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 64(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 192(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vmovaps 256(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vmovaps 384(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 448(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 544(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 576(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vmovaps 640(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 672(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 736(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + 
"vbroadcastss 6(%[src]), %%ymm15\n" + "vmovaps 768(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 800(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 832(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 864(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vmovaps 896(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 928(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 960(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 992(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm3, 0(%[dst_4])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_4] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..a795d14b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..a795d14b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // block0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + // block1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + // block2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + // block3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + // block4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + // block5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + // block6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + // block7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..97a62290 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add $256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +}
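For reference, a scalar sketch of the parameter contract shared by the 1x8 kernel pair above, assuming the conventional GEMM accumulation dst[c] += src[d] * weight[d * 8 + c] and a tight 8x8 weight packing per outer step (the tuned kernels advance their packed weight pointer by a generator-chosen stride). It mirrors only the inc_flag/bias initialization and the act_flag epilogue, not the exact FMA scheduling:

#include <stddef.h>

void gemm_1x8_reference(float *dst, const float *src, const float *weight, const float *bias,
                        size_t act_flag, size_t deep, size_t src_stride, size_t inc_flag) {
  float acc[8];
  for (int c = 0; c < 8; ++c) {
    // inc_flag: accumulate into existing dst; otherwise start from bias (or zero)
    acc[c] = (inc_flag & 1) ? dst[c] : (bias != NULL ? bias[c] : 0.0f);
  }
  for (size_t i = 0; i < (deep >> 3); ++i) {  // depth is consumed 8 steps at a time
    for (int d = 0; d < 8; ++d) {
      for (int c = 0; c < 8; ++c) {
        acc[c] += src[d] * weight[d * 8 + c];  // assumed packing, see note above
      }
    }
    src += src_stride;
    weight += 64;  // 8 depth steps x 8 output channels
  }
  for (int c = 0; c < 8; ++c) {
    if (act_flag & 0x02) {
      acc[c] = acc[c] > 6.0f ? 6.0f : acc[c];  // relu6 upper bound
    }
    if (act_flag & 0x03) {
      acc[c] = acc[c] < 0.0f ? 0.0f : acc[c];  // relu lower bound (both modes clamp at 0)
    }
    dst[c] = acc[c];
  }
}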
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..55787a69 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,145 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst1; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // block0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + // block1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + // block2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + // block3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + // block4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + // block5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = 
_mm256_fmadd_ps(dst0, src05, weight05); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + // block6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + // block7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); +}
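The NC8HW8 suffix in these kernel names refers to the C8-blocked layout that makes whole-__m256 row processing possible: channels are grouped in blocks of 8, and each (h, w) position stores its 8-channel slice contiguously. A minimal addressing sketch under that assumption:

#include <stddef.h>

/* Offset of element (c, h, w) in an NC8HW8 tensor, assuming the usual
 * C8-blocked layout with the 8-channel slice innermost. */
size_t nc8hw8_offset(size_t c, size_t h, size_t w, size_t height, size_t width) {
  size_t block = c / 8;  // which 8-channel block
  size_t lane = c % 8;   // position inside the 8-wide vector
  return ((block * height + h) * width + w) * 8 + lane;
}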
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..d1f7cdce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps 
%%ymm3, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add $512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +}
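The relu6 path in these asm kernels materializes the constant 6.0f from the integer 0x40C00000 (sign 0, exponent 129, mantissa 1.5, so 1.5 * 2^2 = 6.0) and broadcasts it across the vector. A standalone snippet verifying that bit pattern:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
  uint32_t bits = 0x40C00000;
  float six;
  memcpy(&six, &bits, sizeof(six));  // reinterpret the IEEE-754 bit pattern
  printf("%f\n", six);               // prints 6.000000
  return 0;
}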
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..1de0370c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst4; + __m256 dst1; + __m256 dst3; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // block0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + dst4 = _mm256_fmadd_ps(dst4, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + dst5 = _mm256_fmadd_ps(dst5, src10, weight20); + // block1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + dst4 = _mm256_fmadd_ps(dst4, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + dst5 = _mm256_fmadd_ps(dst5, src11, weight21); + // block2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + dst4 = _mm256_fmadd_ps(dst4, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + dst5 = _mm256_fmadd_ps(dst5, src12, weight22); + // block3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + dst4 = _mm256_fmadd_ps(dst4, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + dst5 = _mm256_fmadd_ps(dst5, src13, weight23); + // block4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = 
_mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + dst4 = _mm256_fmadd_ps(dst4, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + dst5 = _mm256_fmadd_ps(dst5, src14, weight24); + // block5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + dst4 = _mm256_fmadd_ps(dst4, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + dst5 = _mm256_fmadd_ps(dst5, src15, weight25); + // block6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + dst4 = _mm256_fmadd_ps(dst4, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + dst5 = _mm256_fmadd_ps(dst5, src16, weight26); + // block7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + dst4 = _mm256_fmadd_ps(dst4, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + dst5 = _mm256_fmadd_ps(dst5, src17, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); + _mm256_store_ps(dst + 2 * src_stride + 0, dst4); + _mm256_store_ps(dst + 2 * src_stride + 8, dst5); +}
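Every kernel in this series keeps the same signature even though row_block and col_block are unused inside the fixed-size bodies; that uniformity is what allows table-driven dispatch on the tile shape. A hypothetical sketch, not part of this patch, assuming the kernel prototypes are visible and filling only the shapes that appear in these files:

typedef void (*GemmAvxKernel)(float *dst, const float *src, const float *weight, const float *bias,
                              const size_t act_flag, const size_t row_block, const size_t col_block,
                              const size_t deep, const size_t src_stride, const size_t dst_stride,
                              const size_t inc_flag);

/* Indexed as kKernels[row_block - 1][col_block / 8 - 1]; NULL marks shapes
 * whose kernels are not shown in this hunk. */
static const GemmAvxKernel kKernels[2][4] = {
  {nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32, NULL,
   nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32, nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32},
  {NULL, nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32,
   nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32, nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32},
};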
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..2c5e9306 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps 
%%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + : 
+ : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..0f9dbfa6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst4; + __m256 dst6; + __m256 dst1; + __m256 dst3; + __m256 dst5; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 3 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 3 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 16); + dst6 = _mm256_load_ps(bias + 24); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 16); + dst7 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst4 = _mm256_fmadd_ps(dst4, src00, weight20); + dst5 = _mm256_fmadd_ps(dst5, src10, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst6 = 
_mm256_fmadd_ps(dst6, src00, weight30); + dst7 = _mm256_fmadd_ps(dst7, src10, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst4 = _mm256_fmadd_ps(dst4, src01, weight21); + dst5 = _mm256_fmadd_ps(dst5, src11, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst6 = _mm256_fmadd_ps(dst6, src01, weight31); + dst7 = _mm256_fmadd_ps(dst7, src11, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst4 = _mm256_fmadd_ps(dst4, src02, weight22); + dst5 = _mm256_fmadd_ps(dst5, src12, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst6 = _mm256_fmadd_ps(dst6, src02, weight32); + dst7 = _mm256_fmadd_ps(dst7, src12, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst4 = _mm256_fmadd_ps(dst4, src03, weight23); + dst5 = _mm256_fmadd_ps(dst5, src13, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst6 = _mm256_fmadd_ps(dst6, src03, weight33); + dst7 = _mm256_fmadd_ps(dst7, src13, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst4 = _mm256_fmadd_ps(dst4, src04, weight24); + dst5 = _mm256_fmadd_ps(dst5, src14, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst6 = _mm256_fmadd_ps(dst6, src04, weight34); + dst7 = _mm256_fmadd_ps(dst7, src14, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst4 = _mm256_fmadd_ps(dst4, src05, weight25); + dst5 = _mm256_fmadd_ps(dst5, src15, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst6 = _mm256_fmadd_ps(dst6, src05, weight35); + dst7 = _mm256_fmadd_ps(dst7, src15, weight35); + // 
bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst4 = _mm256_fmadd_ps(dst4, src06, weight26); + dst5 = _mm256_fmadd_ps(dst5, src16, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst6 = _mm256_fmadd_ps(dst6, src06, weight36); + dst7 = _mm256_fmadd_ps(dst7, src16, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst4 = _mm256_fmadd_ps(dst4, src07, weight27); + dst5 = _mm256_fmadd_ps(dst5, src17, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst6 = _mm256_fmadd_ps(dst6, src07, weight37); + dst7 = _mm256_fmadd_ps(dst7, src17, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); + _mm256_store_ps(dst + 2 * src_stride + 0, dst4); + _mm256_store_ps(dst + 2 * src_stride + 8, dst5); + _mm256_store_ps(dst + 3 * src_stride + 0, dst6); + _mm256_store_ps(dst + 3 * src_stride + 8, dst7); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..142ff106 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except 
in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "vmovups 0(%[dst_4]), %%ymm6\n" + "vmovups 32(%[dst_4]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "vmovaps 96(%[bias]), %%ymm6\n" + "vmovaps 96(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_4] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vmovaps 0(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 32(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 96(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vmovaps 128(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 192(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 224(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, 
%%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 288(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 320(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vmovaps 384(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 416(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 480(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vmovaps 512(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 576(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 608(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 672(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 704(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vmovaps 768(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 800(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 832(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 864(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vmovaps 896(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 928(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 960(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 992(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, 
%%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm6, 0(%[dst_4])\n" + "vmovups %%ymm7, 32(%[dst_4])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t), + [dst_4] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..6fcf7958 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
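All of the generated kernels above and below share the same flag and stride conventions: the deep loop is unrolled by 8, so each iteration advances the weight pointer by 8 * col_block values (the asm variants count in bytes, e.g. 8 * 32 * 4 = 1024 for the 32-wide kernels). A small hedged sketch of those conventions in C, with macro and helper names invented here for illustration only, not taken from this patch:

/* Hypothetical names for the flag bits, inferred from the generated
 * prologues/epilogues in this hunk; illustrative only. */
#include <stddef.h>

#define GEMM_INC_LOAD_DST 0x1  /* prologue: resume from partial sums already in dst, not bias/zero */
#define GEMM_INC_LAST_TILE 0x2 /* asm variants gate activation on this bit; intrinsic variants apply act_flag unconditionally */
#define GEMM_ACT_RELU 0x1
#define GEMM_ACT_RELU6 0x2

/* Per-iteration weight advance: 8 deep steps times col_block output lanes. */
static inline size_t weight_step_floats(size_t col_block) { return 8 * col_block; }
static inline size_t weight_step_bytes(size_t col_block) { return 8 * col_block * sizeof(float); }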
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
new file mode 100644
index 00000000..6fcf7958
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code
+void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                           const size_t act_flag, const size_t row_block, const size_t col_block,
+                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                           const size_t inc_flag) {
+  __m256 dst0;
+  __m256 dst1;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst1 = _mm256_load_ps(bias + 0);
+  }
+  for (int i = 0; i < (deep >> 3); ++i) {
+    // block 0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
+    __m256 src10 = _mm256_set1_ps(*(src + 8));
+    dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
+    // block 1
+    __m256 weight01 = _mm256_load_ps(weight + 8);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
+    __m256 src11 = _mm256_set1_ps(*(src + 9));
+    dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
+    // block 2
+    __m256 weight02 = _mm256_load_ps(weight + 16);
+    __m256 src02 = _mm256_set1_ps(*(src + 2));
+    dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
+    __m256 src12 = _mm256_set1_ps(*(src + 10));
+    dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
+    // block 3
+    __m256 weight03 = _mm256_load_ps(weight + 24);
+    __m256 src03 = _mm256_set1_ps(*(src + 3));
+    dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
+    __m256 src13 = _mm256_set1_ps(*(src + 11));
+    dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
+    // block 4
+    __m256 weight04 = _mm256_load_ps(weight + 32);
+    __m256 src04 = _mm256_set1_ps(*(src + 4));
+    dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
+    __m256 src14 = _mm256_set1_ps(*(src + 12));
+    dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
+    // block 5
+    __m256 weight05 = _mm256_load_ps(weight + 40);
+    __m256 src05 = _mm256_set1_ps(*(src + 5));
+    dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
+    __m256 src15 = _mm256_set1_ps(*(src + 13));
+    dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
+    // block 6
+    __m256 weight06 = _mm256_load_ps(weight + 48);
+    __m256 src06 = _mm256_set1_ps(*(src + 6));
+    dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
+    __m256 src16 = _mm256_set1_ps(*(src + 14));
+    dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
+    // block 7
+    __m256 weight07 = _mm256_load_ps(weight + 56);
+    __m256 src07 = _mm256_set1_ps(*(src + 7));
+    dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
+    __m256 src17 = _mm256_set1_ps(*(src + 15));
+    dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
+    src = src + src_stride;
+    weight += 256;
+  }
+  if (act_flag & 0x02) {
+    // relu6
+    __m256 relu6 = _mm256_set1_ps(6.0f);
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_min_ps(dst0, relu6);
+    dst1 = _mm256_min_ps(dst1, relu6);
+    // relu
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+  }
+  if (act_flag & 0x01) {
+    // relu
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+  }
+  _mm256_store_ps(dst + 0 * src_stride + 0, dst0);
+  _mm256_store_ps(dst + 0 * src_stride + 8, dst1);
+}
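For reference, the 2x8 intrinsic kernel above reduces to the following scalar sketch. Everything here is inferred: the packing comes from the kernel's indexing, the function name is hypothetical, and the intended math is assumed to be acc += src * weight.

/* Scalar restatement of the 2x8 NC8HW8 micro-kernel contract (illustrative). */
#include <math.h>
#include <stddef.h>

void reference_gemm_2x8_nc8hw8(float *dst, const float *src, const float *weight, const float *bias,
                               size_t act_flag, size_t deep, size_t src_stride, size_t dst_stride,
                               size_t inc_flag) {
  /* deep is assumed to be a multiple of 8, matching the unrolled kernels. */
  for (int row = 0; row < 2; ++row) {
    for (int col = 0; col < 8; ++col) {
      /* inc_flag bit 0: resume from partial sums already in dst; else bias/zero. */
      float acc = (inc_flag & 0x1) ? dst[row * dst_stride + col] : (bias != NULL ? bias[col] : 0.0f);
      for (size_t k = 0; k < deep; ++k) {
        /* src packs both rows per 8-deep chunk: row 0 at [0..7], row 1 at [8..15]. */
        float a = src[(k / 8) * src_stride + (size_t)row * 8 + (k % 8)];
        /* weight is deep-major with 8 output lanes per deep step; the generated
         * kernel advances weight by 256 floats per 8-deep chunk, so the same
         * chunk stride is mirrored here. */
        float w = weight[(k / 8) * 256 + (k % 8) * 8 + col];
        acc += a * w;
      }
      if (act_flag & 0x2) acc = fminf(fmaxf(acc, 0.0f), 6.0f); /* relu6 */
      if (act_flag & 0x1) acc = fmaxf(acc, 0.0f);              /* relu */
      /* dst is indexed by dst_stride here; the intrinsic above indexes its stores with src_stride. */
      dst[row * dst_stride + col] = acc;
    }
  }
}

Since every kernel in this hunk shares one signature, a caller can select a variant from a (row_block, col_block) table without further glue.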
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
new file mode 100644
index 00000000..d0724917
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
@@ -0,0 +1,127 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code
+void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                           const size_t act_flag, const size_t row_block, const size_t col_block,
+                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                           const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "and $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "vmovups 32(%[dst]), %%ymm1\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "vmovaps 0(%[bias]), %%ymm1\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "2:\n"
+    :
+    : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag)
+    : "%ymm0", "%ymm1");
+  asm volatile(
+    "0:\n"
+    // block 0
+    "vmovaps 0(%[weight]), %%ymm15\n"
+    "vbroadcastss 0(%[src]), %%ymm14\n"
+    "vbroadcastss 32(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    // block 1
+    "vmovaps 32(%[weight]), %%ymm15\n"
+    "vbroadcastss 1(%[src]), %%ymm14\n"
+    "vbroadcastss 33(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    // block 2
+    "vmovaps 64(%[weight]), %%ymm15\n"
+    "vbroadcastss 2(%[src]), %%ymm14\n"
+    "vbroadcastss 34(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    // block 3
+    "vmovaps 96(%[weight]), %%ymm15\n"
+    "vbroadcastss 3(%[src]), %%ymm14\n"
+    "vbroadcastss 35(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    // block 4
+    "vmovaps 128(%[weight]), %%ymm15\n"
+    "vbroadcastss 4(%[src]), %%ymm14\n"
+    "vbroadcastss 36(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    // block 5
+    "vmovaps 160(%[weight]), %%ymm15\n"
+    "vbroadcastss 5(%[src]), %%ymm14\n"
+    "vbroadcastss 37(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    // block 6
+    "vmovaps 192(%[weight]), %%ymm15\n"
+    "vbroadcastss 6(%[src]), %%ymm14\n"
+    "vbroadcastss 38(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    // block 7
+    "vmovaps 224(%[weight]), %%ymm15\n"
+    "vbroadcastss 7(%[src]), %%ymm14\n"
+    "vbroadcastss 39(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
+    "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
+    "dec %[deep]\n"
+    "add $256, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "jg 0b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "vminps %%ymm1, %%ymm14, %%ymm1\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    "vmovups %%ymm1, 32(%[dst])\n"
+    :
+    : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t),
+      [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
new file mode 100644
index 00000000..8cec9259
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
@@ -0,0 +1,185 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst1; + __m256 dst4; + __m256 dst2; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = 
_mm256_fmadd_ps(dst0, src04, weight04); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..b03ef6e2 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, 
%%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] 
"r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..db5d05d6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst6; + __m256 dst1; + __m256 dst4; + __m256 dst7; + __m256 dst2; + __m256 dst5; + __m256 dst8; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst6 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst7 = _mm256_load_ps(bias + 16); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + dst6 = _mm256_fmadd_ps(dst6, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + dst7 = _mm256_fmadd_ps(dst7, src10, 
weight20); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + dst8 = _mm256_fmadd_ps(dst8, src20, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + dst6 = _mm256_fmadd_ps(dst6, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + dst7 = _mm256_fmadd_ps(dst7, src11, weight21); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + dst8 = _mm256_fmadd_ps(dst8, src21, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + dst6 = _mm256_fmadd_ps(dst6, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + dst7 = _mm256_fmadd_ps(dst7, src12, weight22); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + dst8 = _mm256_fmadd_ps(dst8, src22, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + dst6 = _mm256_fmadd_ps(dst6, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + dst7 = _mm256_fmadd_ps(dst7, src13, weight23); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + dst8 = _mm256_fmadd_ps(dst8, src23, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + dst6 = _mm256_fmadd_ps(dst6, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + dst7 = _mm256_fmadd_ps(dst7, src14, weight24); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + dst8 = _mm256_fmadd_ps(dst8, src24, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + dst6 = _mm256_fmadd_ps(dst6, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 
= _mm256_fmadd_ps(dst1, src15, weight05); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + dst7 = _mm256_fmadd_ps(dst7, src15, weight25); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + dst8 = _mm256_fmadd_ps(dst8, src25, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + dst6 = _mm256_fmadd_ps(dst6, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + dst7 = _mm256_fmadd_ps(dst7, src16, weight26); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + dst8 = _mm256_fmadd_ps(dst8, src26, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + dst6 = _mm256_fmadd_ps(dst6, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + dst7 = _mm256_fmadd_ps(dst7, src17, weight27); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + dst8 = _mm256_fmadd_ps(dst8, src27, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); + _mm256_store_ps(dst + 2 * src_stride + 0, dst6); + _mm256_store_ps(dst + 2 * src_stride + 8, dst7); + _mm256_store_ps(dst + 2 * src_stride + 16, dst8); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..893f3523 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, 
%%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps 
%%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..eaf7595f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst6; + __m256 dst9; + __m256 dst1; + __m256 dst4; + __m256 dst7; + __m256 dst10; + __m256 dst2; + __m256 dst5; + __m256 dst8; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst9 = _mm256_load_ps(dst + 3 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst10 = _mm256_load_ps(dst + 3 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16); + dst11 = _mm256_load_ps(dst + 3 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst6 = _mm256_load_ps(bias + 16); + dst9 = _mm256_load_ps(bias + 24); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst7 = _mm256_load_ps(bias + 16); + dst10 = _mm256_load_ps(bias + 24); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + dst11 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst6 = _mm256_fmadd_ps(dst6, src00, weight20); + dst7 = _mm256_fmadd_ps(dst7, src10, weight20); + dst8 = _mm256_fmadd_ps(dst8, src20, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst9 = _mm256_fmadd_ps(dst9, src00, weight30); + dst10 = _mm256_fmadd_ps(dst10, src10, weight30); + dst11 = _mm256_fmadd_ps(dst11, src20, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + __m256 
weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst6 = _mm256_fmadd_ps(dst6, src01, weight21); + dst7 = _mm256_fmadd_ps(dst7, src11, weight21); + dst8 = _mm256_fmadd_ps(dst8, src21, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst9 = _mm256_fmadd_ps(dst9, src01, weight31); + dst10 = _mm256_fmadd_ps(dst10, src11, weight31); + dst11 = _mm256_fmadd_ps(dst11, src21, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst6 = _mm256_fmadd_ps(dst6, src02, weight22); + dst7 = _mm256_fmadd_ps(dst7, src12, weight22); + dst8 = _mm256_fmadd_ps(dst8, src22, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst9 = _mm256_fmadd_ps(dst9, src02, weight32); + dst10 = _mm256_fmadd_ps(dst10, src12, weight32); + dst11 = _mm256_fmadd_ps(dst11, src22, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst6 = _mm256_fmadd_ps(dst6, src03, weight23); + dst7 = _mm256_fmadd_ps(dst7, src13, weight23); + dst8 = _mm256_fmadd_ps(dst8, src23, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst9 = _mm256_fmadd_ps(dst9, src03, weight33); + dst10 = _mm256_fmadd_ps(dst10, src13, weight33); + dst11 = _mm256_fmadd_ps(dst11, src23, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst6 = _mm256_fmadd_ps(dst6, src04, weight24); + dst7 = _mm256_fmadd_ps(dst7, src14, weight24); + dst8 = _mm256_fmadd_ps(dst8, src24, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst9 = _mm256_fmadd_ps(dst9, src04, weight34); + dst10 = _mm256_fmadd_ps(dst10, src14, weight34); + dst11 = 
_mm256_fmadd_ps(dst11, src24, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst6 = _mm256_fmadd_ps(dst6, src05, weight25); + dst7 = _mm256_fmadd_ps(dst7, src15, weight25); + dst8 = _mm256_fmadd_ps(dst8, src25, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst9 = _mm256_fmadd_ps(dst9, src05, weight35); + dst10 = _mm256_fmadd_ps(dst10, src15, weight35); + dst11 = _mm256_fmadd_ps(dst11, src25, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst6 = _mm256_fmadd_ps(dst6, src06, weight26); + dst7 = _mm256_fmadd_ps(dst7, src16, weight26); + dst8 = _mm256_fmadd_ps(dst8, src26, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst9 = _mm256_fmadd_ps(dst9, src06, weight36); + dst10 = _mm256_fmadd_ps(dst10, src16, weight36); + dst11 = _mm256_fmadd_ps(dst11, src26, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst6 = _mm256_fmadd_ps(dst6, src07, weight27); + dst7 = _mm256_fmadd_ps(dst7, src17, weight27); + dst8 = _mm256_fmadd_ps(dst8, src27, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst9 = _mm256_fmadd_ps(dst9, src07, weight37); + dst10 = _mm256_fmadd_ps(dst10, src17, weight37); + dst11 = _mm256_fmadd_ps(dst11, src27, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + 
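// zero lower bound continues across all 12 accumulators (3 rows x 4 output-channel tiles of 8 floats) +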
dst6 = _mm256_max_ps(dst6, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); + _mm256_store_ps(dst + 2 * src_stride + 0, dst6); + _mm256_store_ps(dst + 2 * src_stride + 8, dst7); + _mm256_store_ps(dst + 2 * src_stride + 16, dst8); + _mm256_store_ps(dst + 3 * src_stride + 0, dst9); + _mm256_store_ps(dst + 3 * src_stride + 8, dst10); + _mm256_store_ps(dst + 3 * src_stride + 16, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..e4d2f5b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,301 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 0(%[dst_4]), %%ymm9\n" + "vmovups 32(%[dst_4]), %%ymm10\n" + "vmovups 64(%[dst_4]), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 96(%[bias]), %%ymm9\n" + "vmovaps 96(%[bias]), %%ymm10\n" + "vmovaps 96(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag), [dst_4] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vmovaps 0(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 32(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 64(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 96(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vmovaps 128(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 160(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 192(%[weight]), 
%%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 224(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vmovaps 256(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 288(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 320(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 352(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vmovaps 384(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 416(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 448(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 480(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vmovaps 512(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 544(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 576(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 608(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vmovaps 640(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 672(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 704(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 736(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vbroadcastss 6(%[src]), 
%%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vmovaps 768(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 800(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 832(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 864(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vmovaps 896(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 928(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 960(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 992(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 0(%[dst_4])\n" + "vmovups %%ymm10, 32(%[dst_4])\n" + "vmovups %%ymm11, 64(%[dst_4])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), 
[dst_stride] "r"(dst_stride_t), + [dst_4] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..aff74f6f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + // bock4 + __m256 weight04 = 
_mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..dae2ecbb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + 
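// epilogue: activation runs only on the final depth pass (inc_flag bit 1); act_flag low bits then select relu, and additionally relu6 (0x40C00000 == 6.0f) +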
"and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..4b34d19b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst1; + __m256 dst5; + __m256 dst2; + __m256 dst6; + __m256 dst3; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = _mm256_fmadd_ps(dst6, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + // 
bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu 
= _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..976b074a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,235 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + 
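// block 3: ymm15/ymm14 carry the two 8-float weight tiles (16 output channels) for this depth step; the broadcasts below pull depth element 3 from each of the four packed rows +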
"vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, 
%%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..8d987c61 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst8; + __m256 dst1; + __m256 dst5; + __m256 dst9; + __m256 dst2; + __m256 dst6; + __m256 dst10; + __m256 dst3; + __m256 dst7; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst9 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst10 = _mm256_load_ps(dst + 2 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst11 = _mm256_load_ps(dst + 2 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst9 = _mm256_load_ps(bias + 16); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst10 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst11 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // block0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + dst8 = _mm256_fmadd_ps(dst8, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + dst9 = _mm256_fmadd_ps(dst9, src10, weight20); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = _mm256_fmadd_ps(dst6, src20, weight10); + dst10 = _mm256_fmadd_ps(dst10, src20, weight20); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + dst11 = _mm256_fmadd_ps(dst11, src30, weight20); + // block1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + dst8 = _mm256_fmadd_ps(dst8, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + dst9 = _mm256_fmadd_ps(dst9, src11, weight21); + __m256 src21 = _mm256_set1_ps(*(src + 
17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + dst10 = _mm256_fmadd_ps(dst10, src21, weight21); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + dst11 = _mm256_fmadd_ps(dst11, src31, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + dst8 = _mm256_fmadd_ps(dst8, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + dst9 = _mm256_fmadd_ps(dst9, src12, weight22); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + dst10 = _mm256_fmadd_ps(dst10, src22, weight22); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + dst11 = _mm256_fmadd_ps(dst11, src32, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + dst8 = _mm256_fmadd_ps(dst8, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + dst9 = _mm256_fmadd_ps(dst9, src13, weight23); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + dst10 = _mm256_fmadd_ps(dst10, src23, weight23); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + dst11 = _mm256_fmadd_ps(dst11, src33, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + dst8 = _mm256_fmadd_ps(dst8, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + dst9 = _mm256_fmadd_ps(dst9, src14, weight24); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + dst10 = _mm256_fmadd_ps(dst10, src24, weight24); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + dst11 = _mm256_fmadd_ps(dst11, src34, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + dst8 = _mm256_fmadd_ps(dst8, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + 
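+ // row 1, depth step 5: the broadcast scalar comes from src[1 * 8 + 5]; rows of the src tile are packed 8 floats apart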
dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + dst9 = _mm256_fmadd_ps(dst9, src15, weight25); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + dst10 = _mm256_fmadd_ps(dst10, src25, weight25); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + dst11 = _mm256_fmadd_ps(dst11, src35, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + dst8 = _mm256_fmadd_ps(dst8, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); + dst9 = _mm256_fmadd_ps(dst9, src16, weight26); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + dst10 = _mm256_fmadd_ps(dst10, src26, weight26); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + dst11 = _mm256_fmadd_ps(dst11, src36, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + dst8 = _mm256_fmadd_ps(dst8, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + dst9 = _mm256_fmadd_ps(dst9, src17, weight27); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + dst10 = _mm256_fmadd_ps(dst10, src27, weight27); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + dst11 = _mm256_fmadd_ps(dst11, src37, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = 
_mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); + _mm256_store_ps(dst + 2 * src_stride + 0, dst8); + _mm256_store_ps(dst + 2 * src_stride + 8, dst9); + _mm256_store_ps(dst + 2 * src_stride + 16, dst10); + _mm256_store_ps(dst + 2 * src_stride + 24, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..3470b7f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm9\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm10\n" + "vmovups 96(%[dst], %[dst_stride], 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 64(%[bias]), %%ymm9\n" + "vmovaps 64(%[bias]), %%ymm10\n" + "vmovaps 64(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, 
%%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 
608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm10, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm11, 96(%[dst], %[dst_stride], 2)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] 
"r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..c972e8a8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = 
_mm256_fmadd_ps(dst3, src32, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..9377253b --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,171 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + 
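+ // block 4: ymm15 holds this depth step's 8-float weight vector, shared by the four per-row src broadcasts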
"vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..5426fb62 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,265 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst5; + __m256 dst1; + __m256 dst6; + __m256 dst2; + __m256 dst7; + __m256 dst3; + __m256 dst8; + __m256 dst4; + __m256 dst9; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst8 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst9 = _mm256_load_ps(dst + 1 * dst_stride + 32); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst5 = _mm256_fmadd_ps(dst5, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst6 = _mm256_fmadd_ps(dst6, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst7 = _mm256_fmadd_ps(dst7, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst8 = _mm256_fmadd_ps(dst8, src30, weight10); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + dst9 = _mm256_fmadd_ps(dst9, src40, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst5 = _mm256_fmadd_ps(dst5, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst6 = _mm256_fmadd_ps(dst6, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst7 = _mm256_fmadd_ps(dst7, src21, weight11); + 
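+ // all five rows of the 5x16 tile reuse the same weight pair (weight01/weight11); only the broadcast src scalar changes per row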
__m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst8 = _mm256_fmadd_ps(dst8, src31, weight11); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + dst9 = _mm256_fmadd_ps(dst9, src41, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst5 = _mm256_fmadd_ps(dst5, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst6 = _mm256_fmadd_ps(dst6, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst7 = _mm256_fmadd_ps(dst7, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst8 = _mm256_fmadd_ps(dst8, src32, weight12); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + dst9 = _mm256_fmadd_ps(dst9, src42, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst5 = _mm256_fmadd_ps(dst5, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst6 = _mm256_fmadd_ps(dst6, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst7 = _mm256_fmadd_ps(dst7, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst8 = _mm256_fmadd_ps(dst8, src33, weight13); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + dst9 = _mm256_fmadd_ps(dst9, src43, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst5 = _mm256_fmadd_ps(dst5, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst6 = _mm256_fmadd_ps(dst6, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst7 = _mm256_fmadd_ps(dst7, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst8 = _mm256_fmadd_ps(dst8, src34, weight14); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + dst9 = _mm256_fmadd_ps(dst9, src44, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst5 = _mm256_fmadd_ps(dst5, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst6 = _mm256_fmadd_ps(dst6, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst7 = _mm256_fmadd_ps(dst7, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst8 = _mm256_fmadd_ps(dst8, src35, weight15); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + 
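+ // the row-4 scalar src45 feeds both 8-float halves of the 16-wide output row (dst4 with weight05, dst9 with weight15)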
dst9 = _mm256_fmadd_ps(dst9, src45, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst5 = _mm256_fmadd_ps(dst5, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst6 = _mm256_fmadd_ps(dst6, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst7 = _mm256_fmadd_ps(dst7, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst8 = _mm256_fmadd_ps(dst8, src36, weight16); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + dst9 = _mm256_fmadd_ps(dst9, src46, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst5 = _mm256_fmadd_ps(dst5, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst6 = _mm256_fmadd_ps(dst6, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst7 = _mm256_fmadd_ps(dst7, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst8 = _mm256_fmadd_ps(dst8, src37, weight17); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + dst9 = _mm256_fmadd_ps(dst9, src47, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 1 * src_stride + 0, dst5); + _mm256_store_ps(dst + 1 * src_stride + 8, dst6); + _mm256_store_ps(dst + 1 * src_stride + 16, dst7); + _mm256_store_ps(dst + 1 * src_stride + 24, dst8); + _mm256_store_ps(dst + 1 * src_stride + 32, 
dst9); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..ba495499 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,271 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + 
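+ // ymm15/ymm14 hold the two 8-float weight halves just loaded; ymm13/ymm12 carry broadcast src scalars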
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, 
%%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add $512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 128(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", +
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..c3a6e1a6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,177 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); +
__m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = 
_mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..2cc835a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,193 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, 
%%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add $256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps 
%%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..5e074d19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,305 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst6; + __m256 dst1; + __m256 dst7; + __m256 dst2; + __m256 dst8; + __m256 dst3; + __m256 dst9; + __m256 dst4; + __m256 dst10; + __m256 dst5; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst9 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst10 = _mm256_load_ps(dst + 1 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst11 = _mm256_load_ps(dst + 1 * dst_stride + 40); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 0); + dst11 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst6 = _mm256_fmadd_ps(dst6, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst7 = _mm256_fmadd_ps(dst7, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst8 = _mm256_fmadd_ps(dst8, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst9 = _mm256_fmadd_ps(dst9, src30, weight10); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + dst10 = _mm256_fmadd_ps(dst10, src40, weight10); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + dst11 = _mm256_fmadd_ps(dst11, src50, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst6 = _mm256_fmadd_ps(dst6, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst7 = _mm256_fmadd_ps(dst7, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst8 = _mm256_fmadd_ps(dst8, src21, weight11); + 
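+ // Tile layout (annotation, inferred from the loads above): the 6x16 kernel keeps two accumulator banks per row, dst0..dst5 for output columns 0-7 and dst6..dst11 for columns 8-15, so each step loads a pair of adjacent 8-float weight columns (e.g. weight01/weight11) and reuses one broadcast src scalar across both banks.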
__m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst9 = _mm256_fmadd_ps(dst9, src31, weight11); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + dst10 = _mm256_fmadd_ps(dst10, src41, weight11); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + dst11 = _mm256_fmadd_ps(dst11, src51, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst6 = _mm256_fmadd_ps(dst6, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst7 = _mm256_fmadd_ps(dst7, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst8 = _mm256_fmadd_ps(dst8, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst9 = _mm256_fmadd_ps(dst9, src32, weight12); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + dst10 = _mm256_fmadd_ps(dst10, src42, weight12); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + dst11 = _mm256_fmadd_ps(dst11, src52, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst6 = _mm256_fmadd_ps(dst6, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst7 = _mm256_fmadd_ps(dst7, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst8 = _mm256_fmadd_ps(dst8, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst9 = _mm256_fmadd_ps(dst9, src33, weight13); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + dst10 = _mm256_fmadd_ps(dst10, src43, weight13); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + dst11 = _mm256_fmadd_ps(dst11, src53, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst6 = _mm256_fmadd_ps(dst6, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst7 = _mm256_fmadd_ps(dst7, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst8 = _mm256_fmadd_ps(dst8, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst9 = _mm256_fmadd_ps(dst9, src34, weight14); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + dst10 = _mm256_fmadd_ps(dst10, src44, weight14); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + dst11 = _mm256_fmadd_ps(dst11, src54, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, 
weight05); + dst6 = _mm256_fmadd_ps(dst6, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst7 = _mm256_fmadd_ps(dst7, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst8 = _mm256_fmadd_ps(dst8, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst9 = _mm256_fmadd_ps(dst9, src35, weight15); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + dst10 = _mm256_fmadd_ps(dst10, src45, weight15); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + dst11 = _mm256_fmadd_ps(dst11, src55, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst6 = _mm256_fmadd_ps(dst6, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst7 = _mm256_fmadd_ps(dst7, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst8 = _mm256_fmadd_ps(dst8, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst9 = _mm256_fmadd_ps(dst9, src36, weight16); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + dst10 = _mm256_fmadd_ps(dst10, src46, weight16); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + dst11 = _mm256_fmadd_ps(dst11, src56, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst6 = _mm256_fmadd_ps(dst6, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst7 = _mm256_fmadd_ps(dst7, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst8 = _mm256_fmadd_ps(dst8, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst9 = _mm256_fmadd_ps(dst9, src37, weight17); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + dst10 = _mm256_fmadd_ps(dst10, src47, weight17); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + dst11 = _mm256_fmadd_ps(dst11, src57, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst8 = 
_mm256_max_ps(dst8, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 1 * src_stride + 0, dst6); + _mm256_store_ps(dst + 1 * src_stride + 8, dst7); + _mm256_store_ps(dst + 1 * src_stride + 16, dst8); + _mm256_store_ps(dst + 1 * src_stride + 24, dst9); + _mm256_store_ps(dst + 1 * src_stride + 32, dst10); + _mm256_store_ps(dst + 1 * src_stride + 40, dst11); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..ba927555 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,307 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm9\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm10\n" + "vmovups 160(%[dst], %[dst_stride], 1), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "vmovaps 32(%[bias]), %%ymm10\n" + "vmovaps 32(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps 
%%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, 
%%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 96(%[dst], %[dst_stride], 1)\n" + "vmovups 
%%ymm10, 128(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm11, 160(%[dst], %[dst_stride], 1)\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..7efdcd8d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, 
weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = 
_mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..67457e0e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,215 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps 
%%ymm5, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add $256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", 
"%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..e99e39da --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + 
dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, 
src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..1923c152 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
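Each kernel in this hunk ships in two bodies with identical signatures: an intrinsic version and a hand-scheduled _asm.c twin. A plausible selection guard between the two, assuming GNU-style inline asm on x86-64 is the only hard requirement of the asm bodies; the macro name is illustrative, not part of the patch:

/* Prefer the hand-scheduled asm body where GNU inline asm is available;
 * fall back to the intrinsic body elsewhere. */
#if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__))
#define NNACL_USE_ASM_GEMM 1
#else
#define NNACL_USE_ASM_GEMM 0
#endif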
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 
130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" 
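In the asm epilogue straddling this point, bit 1 of inc_flag gates the whole activation step, so partial depth chunks are stored unactivated; act_flag is then decoded, with any set bit in the low two selecting the vmaxps relu against a zeroed ymm15 and bit 0 additionally applying the vminps clamp. The clamp constant 0x40C00000 is simply 6.0f viewed as raw IEEE-754 bits, which this standalone check confirms:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
  const uint32_t bits = 0x40C00000u;  /* the immediate broadcast for relu6 */
  float six;
  memcpy(&six, &bits, sizeof six);    /* type-pun without undefined behaviour */
  printf("0x%08X -> %f\n", bits, six); /* prints 0x40C00000 -> 6.000000 */
  return 0;
}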
+ "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..e11e2af8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = 
_mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + 
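The unrolled loop running through this hunk follows a fixed addressing scheme: depth step d of an iteration broadcasts src[8*r + d] for each of the eight rows and multiplies it into the 8-wide weight vector at weight + 8*d. A plain-C model of that contract, written in the conventional accumulate form acc += src * weight and advancing weight by the 64 floats an iteration actually addresses; both are simplifying assumptions of the sketch, not a line-for-line transcription of the intrinsics:

#include <stddef.h>

void gemm_8x8_nc8hw8_ref(float *acc, const float *src, const float *weight,
                         size_t deep, size_t src_stride) {
  /* acc: 8 rows x 8 lanes, row r at acc + 8*r (matches dst0..dst7 above) */
  for (size_t k = 0; k < (deep >> 3); ++k) {
    for (int d = 0; d < 8; ++d) {      /* the 8 unrolled "block d" steps */
      for (int r = 0; r < 8; ++r) {
        float a = src[r * 8 + d];      /* matches *(src + 8*r + d) above */
        for (int c = 0; c < 8; ++c) {
          acc[r * 8 + c] += a * weight[d * 8 + c];
        }
      }
    }
    src += src_stride;                 /* next packed depth chunk */
    weight += 64;                      /* 8 depth steps x 8 lanes addressed */
  }
}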
dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..c391edc0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,259 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
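Every _asm.c kernel here is split into two consecutive asm volatile blocks: the first initialises the ymm accumulators and names them only as clobbers, and the second silently assumes those registers survive into it. Nothing in the inline-asm contract guarantees that; it holds only because the generator emits no code between the blocks. A stripped-down sketch of the same hand-off, with hypothetical names:

/* Block 1 seeds ymm0; block 2 consumes it. Fragile by construction: any
 * compiler-inserted code between the blocks could clobber ymm0. */
static void seed_then_store(float *out, const float *in) {
  asm volatile("vmovups (%[in]), %%ymm0\n"
               :
               : [in] "r"(in)
               : "%ymm0");
  asm volatile("vmovups %%ymm0, (%[out])\n"
               :
               : [out] "r"(out)
               : "%ymm0", "memory");
}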
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 
162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 
167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), + [inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000..3dc30cff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,273 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
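The row-block variants collected so far (6x8, 7x8 and 8x8, with 9x8 following below) share one signature, which makes table dispatch on the residual row count natural. A sketch of such a table; the array and its NULL placeholders are illustrative, and only the four named kernels come from this patch:

#include <stddef.h>

typedef void (*GemmKernel)(float *dst, const float *src, const float *weight,
                           const float *bias, size_t act_flag,
                           size_t row_block, size_t col_block, size_t deep,
                           size_t src_stride, size_t dst_stride,
                           size_t inc_flag);

void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                           size_t, size_t, size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                           size_t, size_t, size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                           size_t, size_t, size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *,
                                           size_t, size_t, size_t, size_t, size_t, size_t, size_t);

/* Index is the rows-per-tile; entries 0..5 are placeholders here. */
static GemmKernel kRowKernels[] = {
  NULL, NULL, NULL, NULL, NULL, NULL,
  nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32,
  nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32,
  nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32,
  nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32,
};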
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src 
+ 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, 
weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000..3f25c287 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,281 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [dst] "r"(dst), [bias] "r"(bias), [dst_stride] "r"(dst_stride_t), [inc_flag] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n"
+ // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
+ // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
+ // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 8(%[src]), %%ymm14\n" + "vbroadcastss 40(%[src]), %%ymm13\n" + "vbroadcastss 72(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 104(%[src]), %%ymm14\n" + "vbroadcastss 136(%[src]), %%ymm13\n" + "vbroadcastss 168(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 200(%[src]), %%ymm14\n" + "vbroadcastss 232(%[src]), %%ymm13\n" + "vbroadcastss 264(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
+ // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 12(%[src]), %%ymm14\n" + "vbroadcastss 44(%[src]), %%ymm13\n" + "vbroadcastss 76(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 108(%[src]), %%ymm14\n" + "vbroadcastss 140(%[src]), %%ymm13\n" + "vbroadcastss 172(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 204(%[src]), %%ymm14\n" + "vbroadcastss 236(%[src]), %%ymm13\n" + "vbroadcastss 268(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
+ // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 16(%[src]), %%ymm14\n" + "vbroadcastss 48(%[src]), %%ymm13\n" + "vbroadcastss 80(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 112(%[src]), %%ymm14\n" + "vbroadcastss 144(%[src]), %%ymm13\n" + "vbroadcastss 176(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 208(%[src]), %%ymm14\n" + "vbroadcastss 240(%[src]), %%ymm13\n" + "vbroadcastss 272(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
+ // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 20(%[src]), %%ymm14\n" + "vbroadcastss 52(%[src]), %%ymm13\n" + "vbroadcastss 84(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 116(%[src]), %%ymm14\n" + "vbroadcastss 148(%[src]), %%ymm13\n" + "vbroadcastss 180(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 212(%[src]), %%ymm14\n" + "vbroadcastss 244(%[src]), %%ymm13\n" + "vbroadcastss 276(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
+ // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 24(%[src]), %%ymm14\n" + "vbroadcastss 56(%[src]), %%ymm13\n" + "vbroadcastss 88(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 120(%[src]), %%ymm14\n" + "vbroadcastss 152(%[src]), %%ymm13\n" + "vbroadcastss 184(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 216(%[src]), %%ymm14\n" + "vbroadcastss 248(%[src]), %%ymm13\n" + "vbroadcastss 280(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
+ // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 28(%[src]), %%ymm14\n" + "vbroadcastss 60(%[src]), %%ymm13\n" + "vbroadcastss 92(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vbroadcastss 124(%[src]), %%ymm14\n" + "vbroadcastss 156(%[src]), %%ymm13\n" + "vbroadcastss 188(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n" + "vbroadcastss 220(%[src]), %%ymm14\n" + "vbroadcastss 252(%[src]), %%ymm13\n" + "vbroadcastss 284(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n" + "dec %[deep]\n" + "add $256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + : + : [src] "r"(src), [src_stride] "r"(src_stride_t), [weight] "r"(weight), [deep] "r"(deep_t), +
[inc_flag] "r"(inc_flag), [act_flag] "r"(act_flag), [dst] "r"(dst), [dst_stride] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..8cb78d4c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,536 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ 
dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), 
%%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], 
%[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, 
%%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + 
: + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..7de18d34 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,784 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", +
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 4(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" 
+ "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 8(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 12(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + 
"vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 16(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 20(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 24(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + 
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 28(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 32(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + 
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 36(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 40(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), 
%%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 44(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 48(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 
52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 52(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 56(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 60(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 
%{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..90642d55 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,577 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", + "%zmm10"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" 
+ "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + 
"vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 
24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), 
%%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 
%{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, 
%%zmm31, %%zmm10 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..8988e02c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,847 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] 
"r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + 
"vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, 
%%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + 
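// in this 12x32 kernel each even zmm accumulates columns 0-15 and its odd partner columns 16-31; only the odd half is computed and stored under mask %k1, so a partial second 16-column tile never writes past col_block +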
"vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "add $2048, %[weight]\n" + 
"add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, 
%%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..1970c957 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,617 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 0(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" +
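// 12x16 variant: each depth step loads one 16-wide weight vector (zmm31) and broadcasts one scalar per output row; every FMA and store is masked with %k1 to respect col_block +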
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], 
%[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + 
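// depth-remainder loop: handles depth % 16 one step at a time with the same broadcast/FMA pattern as the 16x unrolled body above +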
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10 %{{%%k1}}\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10 %{{%%k1}}\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm11, 0(%[dst_9], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ 
inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..c22649ec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,911 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n" + "vmovups 
64(%[dst_9], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 0(%[bias]), %%zmm22\n" + "vmovups 64(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, 
%%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, 
%%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + 
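// block n of the unrolled loop consumes the scalar at byte offset 4*n in each src row against the weight pair at offsets 128*n and 128*n + 64 +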
"vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 8 + 
"vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, 
%%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 
52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, 
%%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" 
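+    // inc_flag bit 1 marks the final depth slice, so the activation runs only once per output
+    // tile; act_flag selects it: vmaxps against the zeroed zmm31 is relu, and the 0x40C00000
+    // constant below is the IEEE-754 encoding of 6.0f for the relu6 vminps cap.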
+ "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21 %{{%%k1}}\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21 %{{%%k1}}\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n" + "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ 
dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3),
+      [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
+      "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
+      "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c
new file mode 100644
index 00000000..d1174d80
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c
@@ -0,0 +1,161 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask) {
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "kmovw (%[mask]), %%k1\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask)
+    : "%rax", "%k1", "%zmm0");
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "kmovw (%[mask]), %%k1\n"
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vbroadcastss 0(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 1
+    "vmovups 64(%[weight]), %%zmm31\n"
+    "vbroadcastss 4(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 2
+    "vmovups 128(%[weight]), %%zmm31\n"
+    "vbroadcastss 8(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 3
+    "vmovups 192(%[weight]), %%zmm31\n"
+    "vbroadcastss 12(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 4
+    "vmovups 256(%[weight]), %%zmm31\n"
+    "vbroadcastss 16(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 5
+    "vmovups 320(%[weight]), %%zmm31\n"
+    "vbroadcastss 20(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 6
+    "vmovups 384(%[weight]), %%zmm31\n"
+    "vbroadcastss 24(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 7
+    "vmovups 448(%[weight]), %%zmm31\n"
+    "vbroadcastss 28(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 8
+    "vmovups 512(%[weight]), %%zmm31\n"
+    "vbroadcastss 32(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 9
+    "vmovups 576(%[weight]), %%zmm31\n"
+    "vbroadcastss 36(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 10
+    "vmovups 640(%[weight]), %%zmm31\n"
+    "vbroadcastss 40(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 11
+    "vmovups 704(%[weight]), %%zmm31\n"
+    "vbroadcastss 44(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 12
+    "vmovups 768(%[weight]), %%zmm31\n"
+    "vbroadcastss 48(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 13
+    "vmovups 832(%[weight]), %%zmm31\n"
+    "vbroadcastss 52(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 14
+    "vmovups 896(%[weight]), %%zmm31\n"
+    "vbroadcastss 56(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    // block 15
+    "vmovups 960(%[weight]), %%zmm31\n"
+    "vbroadcastss 60(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    "add $1024, %[weight]\n"
+    "add $64, %[src_0]\n"
+    "sub $16, %[depth]\n"
+    "cmp $16, %[depth]\n"
+    "jge 0b\n"
+    "cmp $0, %[depth]\n"
+    "je 2f\n"
+    ".align 16\n"
+    "1:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vbroadcastss 0(%[src_0]), %%zmm30\n"
+    "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    "add $64, %[weight]\n"
+    "add $4, %[src_0]\n"
+    "dec %[depth]\n"
+    "jg 1b\n"
+    ".align 16\n"
+    "2:\n"
+    "and $0x2, %[inc_flag]\n"
+    "je 3f\n"
+    "and $0x3, %[act_flag]\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
+    "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n"
+    "and $0x1, %[act_flag]\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n"
+    ".align 16\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n"
+    :
+    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth),
+      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
+      [ mask ] "r"(mask)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
+      "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
+      "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
+}
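In scalar terms, the 1x16 mask kernel above computes one output row over a single, possibly partial, 16-wide column group. A minimal C sketch of the same semantics follows; the reference name gemm_1x16_mask_ref and the flat 16-columns-per-depth-step weight packing are illustrative assumptions for review, not nnacl API:

#include <stddef.h>
#include <stdint.h>

/* Scalar reference (sketch) for nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32. */
void gemm_1x16_mask_ref(float *dst, const float *src, const float *weight, const float *bias,
                        size_t act_flag, size_t depth, size_t inc_flag, const uint16_t *mask) {
  for (int col = 0; col < 16; ++col) {
    /* k1 masks both the FMAs and the store, so masked-out columns are never written back. */
    if ((mask[0] & (1u << col)) == 0) continue;
    /* inc_flag bit 0: resume accumulation from dst; otherwise start from bias (or zero). */
    float acc = (inc_flag & 0x1) ? dst[col] : (bias != NULL ? bias[col] : 0.0f);
    for (size_t d = 0; d < depth; ++d) {
      acc += src[d] * weight[d * 16 + col];  /* weight packed 16 columns per depth step */
    }
    if ((inc_flag & 0x2) && (act_flag & 0x3)) { /* activation only on the final depth slice */
      acc = acc > 0.0f ? acc : 0.0f;            /* relu */
      if (act_flag & 0x1) {
        acc = acc < 6.0f ? acc : 6.0f;          /* relu6; 0x40C00000 encodes 6.0f */
      }
    }
    dst[col] = acc;
  }
}

The unrolled-by-16 main loop and the single-step tail loop in the asm are this inner loop split for scheduling. A caller would typically derive the mask from the column remainder, e.g. (assumption, not shown in this patch) uint16_t m = (uint16_t)((1u << (col_block & 15)) - 1) when col_block is not a multiple of 16.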
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c
new file mode 100644
index 00000000..d4ecdf14
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c
@@ -0,0 +1,201 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                                  const size_t act_flag, const size_t row_block, const size_t col_block,
+                                                  const size_t depth, const size_t src_stride, const size_t dst_stride,
+                                                  const size_t inc_flag, const u_int16_t *mask) {
+  size_t dst_stride_t = dst_stride << 2;
+  asm volatile(
+    // inc in depth
+    "movq %[inc_flag], %%rax\n"
+    "kmovw (%[mask]), %%k1\n"
+    "and $0x1, %%rax\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovups 0(%[bias]), %%zmm0\n"
+    "vmovups 64(%[bias]), %%zmm1\n"
+    "jmp 2f\n"
+    ".align 16\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    ".align 16\n"
+    "2:\n"
+    :
+    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask)
+    : "%rax", "%k1", "%zmm0", "%zmm1");
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    "kmovw (%[mask]), %%k1\n"
+    "cmp $16, %[depth]\n"
+    "jb 1f\n"
+    ".align 16\n"
+    "0:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vbroadcastss 0(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 1
+    "vmovups 128(%[weight]), %%zmm31\n"
+    "vmovups 192(%[weight]), %%zmm30\n"
+    "vbroadcastss 4(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 2
+    "vmovups 256(%[weight]), %%zmm31\n"
+    "vmovups 320(%[weight]), %%zmm30\n"
+    "vbroadcastss 8(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 3
+    "vmovups 384(%[weight]), %%zmm31\n"
+    "vmovups 448(%[weight]), %%zmm30\n"
+    "vbroadcastss 12(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 4
+    "vmovups 512(%[weight]), %%zmm31\n"
+    "vmovups 576(%[weight]), %%zmm30\n"
+    "vbroadcastss 16(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 5
+    "vmovups 640(%[weight]), %%zmm31\n"
+    "vmovups 704(%[weight]), %%zmm30\n"
+    "vbroadcastss 20(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 6
+    "vmovups 768(%[weight]), %%zmm31\n"
+    "vmovups 832(%[weight]), %%zmm30\n"
+    "vbroadcastss 24(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 7
+    "vmovups 896(%[weight]), %%zmm31\n"
+    "vmovups 960(%[weight]), %%zmm30\n"
+    "vbroadcastss 28(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 8
+    "vmovups 1024(%[weight]), %%zmm31\n"
+    "vmovups 1088(%[weight]), %%zmm30\n"
+    "vbroadcastss 32(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 9
+    "vmovups 1152(%[weight]), %%zmm31\n"
+    "vmovups 1216(%[weight]), %%zmm30\n"
+    "vbroadcastss 36(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 10
+    "vmovups 1280(%[weight]), %%zmm31\n"
+    "vmovups 1344(%[weight]), %%zmm30\n"
+    "vbroadcastss 40(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 11
+    "vmovups 1408(%[weight]), %%zmm31\n"
+    "vmovups 1472(%[weight]), %%zmm30\n"
+    "vbroadcastss 44(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 12
+    "vmovups 1536(%[weight]), %%zmm31\n"
+    "vmovups 1600(%[weight]), %%zmm30\n"
+    "vbroadcastss 48(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 13
+    "vmovups 1664(%[weight]), %%zmm31\n"
+    "vmovups 1728(%[weight]), %%zmm30\n"
+    "vbroadcastss 52(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 14
+    "vmovups 1792(%[weight]), %%zmm31\n"
+    "vmovups 1856(%[weight]), %%zmm30\n"
+    "vbroadcastss 56(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    // block 15
+    "vmovups 1920(%[weight]), %%zmm31\n"
+    "vmovups 1984(%[weight]), %%zmm30\n"
+    "vbroadcastss 60(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    "add $2048, %[weight]\n"
+    "add $64, %[src_0]\n"
+    "sub $16, %[depth]\n"
+    "cmp $16, %[depth]\n"
+    "jge 0b\n"
+    "cmp $0, %[depth]\n"
+    "je 2f\n"
+    ".align 16\n"
+    "1:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vbroadcastss 0(%[src_0]), %%zmm29\n"
+    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n"
+    "add $128, %[weight]\n"
+    "add $4, %[src_0]\n"
+    "dec %[depth]\n"
+    "jg 1b\n"
+    ".align 16\n"
+    "2:\n"
+    "and $0x2, %[inc_flag]\n"
+    "je 3f\n"
+    "and $0x3, %[act_flag]\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
+    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
+    "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n"
+    "and $0x1, %[act_flag]\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0\n"
+    "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n"
+    ".align 16\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0])\n"
+    "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n"
+    :
+    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth),
+      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
+      [ mask ] "r"(mask)
+    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
+      "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
+      "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
+}
diff --git
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..149bb2d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 
12(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, 
%%zmm2 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..4e944705 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,281 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), 
%%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 
3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..885fad16 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,321 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 
1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, 
%%zmm31, %%zmm4 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..12eecaf6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,361 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
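The mask operand these kernels dereference with kmovw is declared as const u_int16_t *; u_int16_t is the BSD/glibc spelling from <sys/types.h>, so on non-glibc toolchains the portable equivalent is uint16_t from <stdint.h>. A minimal helper for producing that mask, assuming the usual tail-handling convention (the low bits enable the valid lanes of the last 16-float block), could look like the sketch below; avx512_tail_mask is a hypothetical name, not part of the patch.

#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: 16-bit lane mask for the last zmm block of an output
 * row of `cols` floats; a full final block enables all 16 lanes. */
static uint16_t avx512_tail_mask(size_t cols) {
  size_t rem = cols % 16;
  return (uint16_t)(rem == 0 ? 0xFFFFu : ((1u << rem) - 1u));
}

Under standard AVX-512 semantics, the {%k1} suffix without {z} is merge masking, so disabled lanes of the accumulator registers are never modified, and the masked vmovups store writes only the enabled lanes, leaving dst bytes past the real row width intact.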
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 
1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + 
"sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..7e4d87e6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..be5f4408 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,264 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], 
%[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", 
"%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..996d1af1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,327 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 1 + 
"vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 
1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 15 + 
"vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..63d538b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,390 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag, const u_int16_t *mask) {
+ size_t dst_stride_t = dst_stride << 2;
+ asm volatile(
+ // inc in depth
+ "movq %[inc_flag], %%rax\n"
+ "kmovw (%[mask]), %%k1\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 64(%[dst_0]), %%zmm1\n"
+ "vmovups 128(%[dst_0]), %%zmm2\n"
+ "vmovups 192(%[dst_0]), %%zmm3\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+ "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+ "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 64(%[bias]), %%zmm1\n"
+ "vmovups 128(%[bias]), %%zmm2\n"
+ "vmovups 192(%[bias]), %%zmm3\n"
+ "vmovups 0(%[bias]), %%zmm4\n"
+ "vmovups 64(%[bias]), %%zmm5\n"
+ "vmovups 128(%[bias]), %%zmm6\n"
+ "vmovups 192(%[bias]), %%zmm7\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "1:\n"
+ "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+ "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+ "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+ "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+ "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+ "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+ "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+ "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+ ".align 16\n"
+ "2:\n"
+ :
+ : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask)
+ : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7");
+ size_t src_stride_t = src_stride << 2;
+ asm volatile(
+ "kmovw (%[mask]), %%k1\n"
+ "cmp $16, %[depth]\n"
+ "jb 1f\n"
+ ".align 16\n"
+ "0:\n"
+ // block 0
+ "vmovups 0(%[weight]), %%zmm31\n"
+ "vmovups 64(%[weight]), %%zmm30\n"
+ "vmovups 128(%[weight]), %%zmm29\n"
+ "vmovups 192(%[weight]), %%zmm28\n"
+ "vbroadcastss 0(%[src_0]), %%zmm27\n"
+ "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
+ "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
+ "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
+ "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
+ "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n"
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
+ "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
+ "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
+ "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n"
+ // block 1
+ "vmovups 256(%[weight]), %%zmm31\n"
+ "vmovups 320(%[weight]), %%zmm30\n"
+ "vmovups 384(%[weight]), %%zmm29\n"
+ "vmovups 448(%[weight]), %%zmm28\n"
+ "vbroadcastss 4(%[src_0]), %%zmm27\n"
+ "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n"
+ "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
+ "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
+ "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
+ "vfmadd231ps %%zmm28, %%zmm27, %%zmm3
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), 
%%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + 
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + 
"vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..d18a9654 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,453 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag, const u_int16_t *mask) {
+ size_t dst_stride_t = dst_stride << 2;
+ asm volatile(
+ // inc in depth
+ "movq %[inc_flag], %%rax\n"
+ "kmovw (%[mask]), %%k1\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 64(%[dst_0]), %%zmm1\n"
+ "vmovups 128(%[dst_0]), %%zmm2\n"
+ "vmovups 192(%[dst_0]), %%zmm3\n"
+ "vmovups 256(%[dst_0]), %%zmm4\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+ "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+ "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
+ "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 64(%[bias]), %%zmm1\n"
+ "vmovups 128(%[bias]), %%zmm2\n"
+ "vmovups 192(%[bias]), %%zmm3\n"
+ "vmovups 256(%[bias]), %%zmm4\n"
+ "vmovups 0(%[bias]), %%zmm5\n"
+ "vmovups 64(%[bias]), %%zmm6\n"
+ "vmovups 128(%[bias]), %%zmm7\n"
+ "vmovups 192(%[bias]), %%zmm8\n"
+ "vmovups 256(%[bias]), %%zmm9\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "1:\n"
+ "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+ "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+ "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+ "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+ "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+ "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+ "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+ "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+ "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+ "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+ ".align 16\n"
+ "2:\n"
+ :
+ : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask)
+ : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9");
+ size_t src_stride_t = src_stride << 2;
+ asm volatile(
+ "kmovw (%[mask]), %%k1\n"
+ "cmp $16, %[depth]\n"
+ "jb 1f\n"
+ ".align 16\n"
+ "0:\n"
+ // block 0
+ "vmovups 0(%[weight]), %%zmm31\n"
+ "vmovups 64(%[weight]), %%zmm30\n"
+ "vmovups 128(%[weight]), %%zmm29\n"
+ "vmovups 192(%[weight]), %%zmm28\n"
+ "vmovups 256(%[weight]), %%zmm27\n"
+ "vbroadcastss 0(%[src_0]), %%zmm26\n"
+ "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+ "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+ "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+ "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+ "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n"
+ "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+ "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+ "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+ "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+ "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n"
+ // block 1
+ "vmovups 320(%[weight]), %%zmm31\n"
+ "vmovups 384(%[weight]), %%zmm30\n"
+ "vmovups 448(%[weight]), %%zmm29\n"
+ "vmovups 512(%[weight]), %%zmm28\n"
+ "vmovups 576(%[weight]), %%zmm27\n"
+ "vbroadcastss 4(%[src_0]), %%zmm26\n"
+ "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+ "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+ "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+ "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+ "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n"
+
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" 
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 
3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + 
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c
new file mode 100644
index 00000000..ddc82b0a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c
@@ -0,0 +1,517 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag, const u_int16_t *mask) {
+ size_t dst_stride_t = dst_stride << 2;
+ asm volatile(
+ // inc in depth
+ "movq %[inc_flag], %%rax\n"
+ "kmovw (%[mask]), %%k1\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 64(%[dst_0]), %%zmm1\n"
+ "vmovups 128(%[dst_0]), %%zmm2\n"
+ "vmovups 192(%[dst_0]), %%zmm3\n"
+ "vmovups 256(%[dst_0]), %%zmm4\n"
+ "vmovups 320(%[dst_0]), %%zmm5\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+ "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+ "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
+ "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
+ "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
+ "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 64(%[bias]), %%zmm1\n"
+ "vmovups 128(%[bias]), %%zmm2\n"
+ "vmovups 192(%[bias]), %%zmm3\n"
+ "vmovups 256(%[bias]), %%zmm4\n"
+ "vmovups 320(%[bias]), %%zmm5\n"
+ "vmovups 0(%[bias]), %%zmm6\n"
+ "vmovups 64(%[bias]), %%zmm7\n"
+ "vmovups 128(%[bias]), %%zmm8\n"
+ "vmovups 192(%[bias]), %%zmm9\n"
+ "vmovups 256(%[bias]), %%zmm10\n"
+ "vmovups 320(%[bias]), %%zmm11\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "1:\n"
+ "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+ "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+ "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+ "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+ "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+ "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+ "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+ "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+ "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+ "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+ "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+ "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+ ".align 16\n"
+ "2:\n"
+ :
+ : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask)
+ : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
+ "%zmm11");
+ size_t src_stride_t = src_stride << 2;
+ asm volatile(
+ "kmovw (%[mask]), %%k1\n"
+ "cmp $16, %[depth]\n"
+ "jb 1f\n"
+ ".align 16\n"
+ "0:\n"
+ // block 0
+ "vmovups
0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 
1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 
3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 
4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" 
+ "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c
new file mode 100644
index 00000000..bb3bc480
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c
@@ -0,0 +1,241 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+// clang-format off
+// nnacl gemm in x86 avx512 asm code
+void nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+ const size_t act_flag, const size_t row_block, const size_t col_block,
+ const size_t depth, const size_t src_stride, const size_t dst_stride,
+ const size_t inc_flag, const u_int16_t *mask) {
+ size_t dst_stride_t = dst_stride << 2;
+ asm volatile(
+ // inc in depth
+ "movq %[inc_flag], %%rax\n"
+ "kmovw (%[mask]), %%k1\n"
+ "and $0x1, %%rax\n"
+ "je 0f\n"
+ "vmovups 0(%[dst_0]), %%zmm0\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n"
+ "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "0:\n"
+ "cmpq $0, %[bias]\n"
+ "je 1f\n"
+ "vmovups 0(%[bias]), %%zmm0\n"
+ "vmovups 0(%[bias]), %%zmm1\n"
+ "vmovups 0(%[bias]), %%zmm2\n"
+ "jmp 2f\n"
+ ".align 16\n"
+ "1:\n"
+ "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+ "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+ "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+ ".align 16\n"
+ "2:\n"
+ :
+ : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask)
+ : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2");
+ size_t src_stride_t = src_stride << 2;
+ asm volatile(
+ "kmovw (%[mask]), %%k1\n"
+ "cmp $16, %[depth]\n"
+ "jb 1f\n"
+ ".align 16\n"
+ "0:\n"
+ // block 0
+ "vmovups 0(%[weight]), %%zmm31\n"
+ "vbroadcastss 0(%[src_0]), %%zmm30\n"
+ "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n"
+ "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n"
+ "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+ "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n"
+ "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n"
+ // block 1
+ "vmovups 64(%[weight]), %%zmm31\n"
+ "vbroadcastss 4(%[src_0]), %%zmm30\n"
+ "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n"
+ "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n"
+ "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+ "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n"
+ "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n"
+ // block 2
+ "vmovups 128(%[weight]), %%zmm31\n"
+ "vbroadcastss 8(%[src_0]), %%zmm30\n"
+ "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n"
+ "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n"
+ "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n"
+ "vfmadd231ps %%zmm31, %%zmm29, %%zmm1
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + 
"vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..dc26494b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,327 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + 
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + 
"vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..e0c94133 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,413 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + 
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + 
"vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, 
%%zmm26, %%zmm8 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 
%{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..7dce3d6b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,500 @@ +/** + * Copyright 
2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps 
%%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, 
%%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps 
%%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..d039b5b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,586 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), 
%%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, 
%%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + 
"vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, 
%%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + 
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..11e7ae25 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,672 @@ +/** + 
* Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias 
] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, 
%%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, 
%%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_0], %[src_stride], 
2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" 
+ "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $384, %[weight]\n" + 
"add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c new file mode 
100644 index 00000000..211edb7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,286 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], 
%[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 
640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, 
%[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..697b6c9c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,395 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + 
"vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), 
%%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, 
%%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, 
%%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..534ecd34 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,505 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 
192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), 
%%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), 
%%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 
2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, 
%%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..3dd8b7e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,614 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" 
+ "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], 
%[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + 
"cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 
64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..59e3ac4e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,723 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t
src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + 
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_3]), %%zmm23\n" 
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // 
block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps 
%%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp 
$16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 
192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..6bdd5140 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,833 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_3]), %%zmm18\n" + "vmovups 64(%[dst_3]), %%zmm19\n" + "vmovups 128(%[dst_3]), %%zmm20\n" + "vmovups 192(%[dst_3]), %%zmm21\n" + "vmovups 256(%[dst_3]), %%zmm22\n" + "vmovups 320(%[dst_3]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 192(%[bias]), %%zmm21\n" + "vmovups 256(%[bias]), %%zmm22\n" + "vmovups 320(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n"
+ "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, 
%%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), 
%%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, 
%%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + 
"vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, 
%%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps 
%%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, 
%%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_3])\n" + "vmovups %%zmm19, 64(%[dst_3])\n" + "vmovups %%zmm20, 128(%[dst_3])\n" + "vmovups %%zmm21, 192(%[dst_3])\n" + "vmovups %%zmm22, 256(%[dst_3])\n" + "vmovups %%zmm23, 320(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..9306e9d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,326 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss
8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps 
%%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + 
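// scalar reference sketch of the assumed semantics of this masked 5x16 kernel (annotation, not generated code; +
// strides are counted in floats and bit c of *mask gates output column c): +
//   for (size_t d = 0; d < depth; ++d) +
//     for (int r = 0; r < 5; ++r) +
//       for (int c = 0; c < 16; ++c) +
//         if ((mask[0] >> c) & 1) dst[r * dst_stride + c] += src[r * src_stride + d] * weight[d * 16 + c]; +
//   followed by optional relu/relu6 clamping selected via act_flag +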
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..c1aea19b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,458 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" +
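// five source rows are broadcast per depth step (rows 0-2 addressed from src_0 via src_stride, rows 3-4 from the precomputed src_3); with two 16-float columns per row, only the odd accumulators (zmm1/zmm3/zmm5/zmm7/zmm9) carry the %k1 tail mask +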
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + 
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), 
%%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + 
"2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..dd6e4627 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,591 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28,
%%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 
7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + 
"vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + 
"vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..fa03a12d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,723 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps
%%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps 
%%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, 
%%zmm23, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, 
%%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + 
"vbroadcastss 48(%[src_3]), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, 
%%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..b2e60fa1 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,856 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n" + "vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "vmovups 256(%[bias]), %%zmm24\n" + "jmp
2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "vxorps %%zmm24, %%zmm24, %%zmm24\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + 
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + 
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + 
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" 
+ "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 
3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + 
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 
56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, 
%%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "vmaxps %%zmm24, %%zmm31, %%zmm24 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "vminps %%zmm24, %%zmm30, %%zmm24 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], 
%[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..963d83d5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,366 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), 
%%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" 
+ "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..db0e244d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,522 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + 
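+  // 6x32 tile: 12 accumulators zmm0-zmm11 (6 rows x 2 groups of 16 floats);
+  // k1 masks the second, possibly-partial group (the odd-numbered accumulators).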
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 
32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + 
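+    // addressing: the src offsets (0, 4, 8, ...) step four bytes per depth element
+    // while the weight offsets step 128 bytes (two zmm vectors), so this unrolled
+    // pass consumes 16 depth elements before the pointer bumps at its end.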
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + 
"vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + 
"vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..38dfe4c1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,677 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), 
%%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, 
%%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), 
%%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), 
%%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, 
%%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 
] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..3dd0b65e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,833 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + 
"0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 
0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 
20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps 
%%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, 
%%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), 
%%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + 
"vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, 
%%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..67f76c0b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,410 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + 
"vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + 
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + 
"vbroadcastss 60(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..16c5ed1e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,589 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + ".align 
16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + 
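/* Reviewer note, editorial and not part of the generated output: every
 * "block N" section above and below is one unrolled depth step of the inner
 * product. In this 7x32 kernel, two vmovups pull the next pair of 16-float
 * weight panels into zmm31/zmm30, one vbroadcastss per output row splats
 * src[row][d], and the vfmadd231ps pairs accumulate into the fourteen per-row
 * accumulators zmm0..zmm13; the k1 write-mask, loaded once from *mask via
 * kmovw, predicates only the second (possibly partial) 16-column panel. One
 * point worth flagging in review: %[weight], %[src_0], %[src_3], %[src_6],
 * and %[depth] are bound as input-only operands yet are advanced with
 * add/sub/dec inside the template, which GCC's inline-asm contract does not
 * permit; the generator presumably relies on these values not being reused
 * after the statement. */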
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 
%{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], 
%[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 
1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 
%{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..928e80c7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,767 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4",
"%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 4(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + 
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 8(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 12(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 16(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, 
%%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 20(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 24(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, 
%%zmm20 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 28(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 32(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 36(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + 
"vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 40(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 44(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 48(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 52(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 
56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 56(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 60(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 
0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..644154c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,450 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss
4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + 
"vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 
36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), 
%%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + 
".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..bc15930e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,652 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss
0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 
%{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 
%{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, 
%%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], 
%[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), 
%%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps 
%%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", 
"%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..d3c08cf0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,854 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm21\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm22\n" + "vmovups 128(%[dst_6], %[dst_stride], 1), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), 
%%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 0(%[bias]), %%zmm21\n" + "vmovups 64(%[bias]), %%zmm22\n" + "vmovups 128(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_6]), %%zmm27\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_6]), %%zmm27\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 3 + 
"vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_6]), %%zmm27\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_6]), %%zmm27\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), 
%%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_6]), %%zmm27\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_6]), %%zmm27\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_6]), %%zmm27\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_6]), %%zmm27\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps 
%%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_6]), %%zmm27\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_6]), %%zmm27\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_6]), %%zmm27\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_6]), %%zmm27\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" 
+ "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_6]), %%zmm27\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_6]), %%zmm27\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_6]), %%zmm27\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), 
%%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20 %{{%%k1}}\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps 
%%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20 %{{%%k1}}\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm21, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm22, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm23, 128(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..76431df8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c @@ -0,0 
+1,490 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, 
%%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + 
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" 
+ "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 
44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps 
%%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 
%{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +}
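(Reviewer note, not part of the patch: the generated mask kernels in this series all implement one contract that the asm obscures. Below is a minimal scalar sketch of that contract for the 9x16 variant, under these assumptions read off the asm: bit 0 of inc_flag selects accumulation into dst rather than (re)initialization from bias (zeros when bias is NULL); bit 1 of inc_flag marks the final depth slice, where activation may run; act_flag & 3 enables relu, and act_flag & 1 additionally clamps at 6.0f (the 0x40C00000 constant); bit c of *mask enables output column c, which is what the %k1 opmask does for the tail 16-column slice. The wider 9x32/9x48 variants mask only their last 16 columns. The name gemm_9x16_mask_ref and the standard uint16_t in place of u_int16_t are ours.)

#include <stddef.h>
#include <stdint.h>

/* Hypothetical scalar reference for the masked 9x16 kernel's contract.
 * dst: row_block x col_block, row stride dst_stride (in floats);
 * src: row_block x depth, row stride src_stride (in floats);
 * weight: depth x col_block, row-major (col_block == 16 here). */
static void gemm_9x16_mask_ref(float *dst, const float *src, const float *weight, const float *bias,
                               size_t act_flag, size_t row_block, size_t col_block, size_t depth,
                               size_t src_stride, size_t dst_stride, size_t inc_flag,
                               const uint16_t *mask) {
  for (size_t r = 0; r < row_block; ++r) {
    for (size_t c = 0; c < col_block; ++c) {
      if (((*mask >> c) & 1) == 0) {
        continue; /* column masked off by %k1: neither computed nor stored */
      }
      /* inc_flag bit 0: accumulate into existing dst; else init from bias or zero */
      float acc = (inc_flag & 1) ? dst[r * dst_stride + c] : (bias != NULL ? bias[c] : 0.0f);
      for (size_t d = 0; d < depth; ++d) {
        acc += src[r * src_stride + d] * weight[d * col_block + c];
      }
      /* activation runs only on the final depth slice (inc_flag bit 1) */
      if ((inc_flag & 2) && (act_flag & 3)) {
        acc = acc > 0.0f ? acc : 0.0f;                   /* relu */
        if (act_flag & 1) acc = acc < 6.0f ? acc : 6.0f; /* relu6 */
      }
      dst[r * dst_stride + c] = acc;
    }
  }
}

(End of reviewer note.)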
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 00000000..4b3e39b7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,715 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 
16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, 
%%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), 
%%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, 
%%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and 
$0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generate_hpc.sh 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generate_hpc.sh new file mode 100644 index 00000000..cd2fbf51 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generate_hpc.sh @@ -0,0 +1,88 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +CRTDIR=$( cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# generate gemm fma asm code +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I 
./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c +# +## generate gemm fma intrinsics code +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O
./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c + +# generate gemm avx512 asm code +n=(96 80 64 48 32 16) +m=(4 5 6 8 12 12) +for ((index = 0; index < 6; index++)) +do + for ((row = 1; row <= ${m[index]}; row++)) + do + dst_file=$CRTDIR"/gemm_avx512/nnacl_gemm_avx512_$row""x${n[index]}_kernel_nhwc_fp32.c" + python3 $CRTDIR/generator.py -I $CRTDIR/template_file/gemm_avx512_nhwc_asm.c.in -A row_block=$row col_block=${n[index]} -O $dst_file + done +done diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generator.py b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generator.py new file mode 100644 index 00000000..89342097 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/generator.py @@ -0,0 +1,162 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
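The generation loop at the end of generate_hpc.sh is the only part of the script that still runs; the commented-out python3 lines above it are kept for reference. The active loop sweeps six (col_block, max row_block) pairs and calls generator.py once per tile size. A minimal Python sketch of the same driver, assuming nothing beyond the -I/-A/-O command-line interface that generator.py defines below (paths are illustrative):

import subprocess

COL_BLOCKS = (96, 80, 64, 48, 32, 16)  # n=(...) in the script
MAX_ROWS = (4, 5, 6, 8, 12, 12)        # m=(...) in the script

for col, max_row in zip(COL_BLOCKS, MAX_ROWS):
    for row in range(1, max_row + 1):
        dst = f"gemm_avx512/nnacl_gemm_avx512_{row}x{col}_kernel_nhwc_fp32.c"
        subprocess.run(["python3", "generator.py",
                        "-I", "template_file/gemm_avx512_nhwc_asm.c.in",
                        "-A", f"row_block={row}", f"col_block={col}",
                        "-O", dst], check=True)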
+# ============================================================================ +"""HPC generator""" + +import sys +import os +import io +import stat +import argparse +from itertools import chain + +def key_value_pair(line): + """ + Split a key=value pair into key and integer value + :param line: + :return: + """ + key = None + value = None + try: + key, value = line.split("=", 1) + except ValueError: + print("line must be in the format key=value, but now is:", line) + sys.exit(1) + try: + value = int(value) + except ValueError: + print("Error: the input value must be an integer, but now is:", value) + sys.exit(1) + return key, value + +def get_indent(line): + """ + Get the indent (leading-space) length of a line + :param line: + :return: + """ + index = 0 + for i in line: + if i == " ": + index += 1 + else: + break + return index + +def print_line(line): + """ + Convert a template line to a Python print statement + :param line: + :return: + """ + global PYTHON_INDENT + global GENERATE_CODE_INDENT + if line.strip()[0] == "}" or line.strip()[0] == ")": + PYTHON_INDENT = -1 + split_str = line.split("@") + if line.strip()[0] != "@" and len(split_str) == 1: + if get_indent(line) == PYTHON_INDENT or PYTHON_INDENT == -1: + result = ["print(", line, ", file=OUT_STREAM)"] + PYTHON_INDENT = -1 + if "{" in line or "asm volatile(" in line: + GENERATE_CODE_INDENT = get_indent(line) + if line.strip().startswith("}") and "{" not in line: + GENERATE_CODE_INDENT -= 4 + if len(line) == 1 and line[0] == "}": + # modify the next function's GENERATE_CODE_INDENT + GENERATE_CODE_INDENT = -4 + return "\"".join(result) + + if line.strip()[0] == '@': + # get python indent and first GENERATE_CODE_INDENT + if PYTHON_INDENT == -1: + GENERATE_CODE_INDENT = get_indent(line) - 4 + PYTHON_INDENT = get_indent(line) + result = split_str[0][PYTHON_INDENT:] + split_str[1] + return result + + index = get_indent(split_str[0]) + result = [split_str[0][PYTHON_INDENT:index] + "print("] + prefix = " " * (GENERATE_CODE_INDENT + 4) + split_str[0].lstrip() + + suffix = " %(" + for str_tmp in split_str[1:]: + second = str_tmp.find("}") + suffix += str_tmp[1:second] + ', ' + str_tmp = str_tmp.replace(str_tmp[0:second + 1], "%d") + prefix += str_tmp + result.append(prefix) + result.append(suffix + "), file=OUT_STREAM)") + return "\"".join(result) + +def generate_code(template_file, exec_dict): + """ + Generate HPC code from a template + :param template_file: template file path + :param exec_dict: dict + :return: hpc + """ + output_stream = io.StringIO() + with open(template_file, 'r') as f: + generate_code_lines = [] + for line in f: + line = line.replace("\n", "") + if line.strip() and line.strip()[0] != "@": + line = line.replace("\"", "\\\"") + line = line.replace("%", "%%") + if "print" in line: + line = line.replace("%%", "%") + if not line: + generate_code_lines.append("print(" + "\"" + line + "\"" + ", file=OUT_STREAM)") + else: + converted = print_line(line) + if "%(" not in converted: + converted = converted.replace("%%[", "%[") + generate_code_lines.append(converted) + c = compile('\n'.join(generate_code_lines), '', 'exec') + exec_dict["OUT_STREAM"] = output_stream + exec(c, exec_dict) + return output_stream.getvalue() + +def check_python_version(): + if sys.version_info < (3, 6): + sys.stdout.write("At least Python 3.6 is required, but the current version is " + str(sys.version_info.major) + "."
+ + str(sys.version_info.minor) + "\n") + sys.exit(1) + +GENERATE_CODE_INDENT = -4 +PYTHON_INDENT = -1 + +parser = argparse.ArgumentParser(description="MSLite NNACL Code Generator") +parser.add_argument("-I", dest="Template_File", nargs=1, help="template file to generate code") +parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=key_value_pair, action="append", + help="Custom Parameters") +parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path") + +if __name__ == "__main__": + check_python_version() + parameters = parser.parse_args(sys.argv[1:]) + exec_globals = dict(chain(*parameters.defines)) + + generate_code_str = generate_code(parameters.Template_File[0], exec_globals) + if os.path.exists(parameters.Output_File[0]): + os.remove(parameters.Output_File[0]) + + saveDir = os.path.dirname(parameters.Output_File[0]) + if not os.path.exists(saveDir): + os.mkdir(saveDir, 0o700) + with open(parameters.Output_File[0], "w", encoding='utf-8') as output_file: + output_file.write(generate_code_str) + os.chmod(parameters.Output_File[0], stat.S_IWUSR + stat.S_IRUSR) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_mask_nhwc_asm.c.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_mask_nhwc_asm.c.in new file mode 100644 index 00000000..d60110df --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_mask_nhwc_asm.c.in @@ -0,0 +1,263 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
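generator.py implements the small template dialect used by the .c.in files that follow: a line whose first non-blank character is @ is executed as Python (the loops, assignments, and conditionals that drive generation), an @{expr} inside an ordinary line is substituted with the integer value of the expression, every other line is printed verbatim into the generated C file, and each -A key=value define becomes a global variable visible to the template. A toy illustration, not taken from this patch, of what the compiled form of a template line looks like:

import sys

# Toy template (two lines):
#     @for r in range(0, row_block):
#       dst[@{r}] += src[@{r}] * scale;
# generator.py compiles it into roughly the Python below, then execs it
# with row_block bound from -A and OUT_STREAM bound to the output file.
OUT_STREAM = sys.stdout
row_block = 3
for r in range(0, row_block):
    print("  dst[%d] += src[%d] * scale;" % (r, r), file=OUT_STREAM)
# emits: dst[0] += src[0] * scale;  dst[1] += ...  dst[2] += ...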
+ */ +#include <x86intrin.h> +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_@{row_block}x@{col_block}_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, const u_int16_t* mask) { + @import math + @row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20} + @src_addr_stride = 3 + @asm_flag_list = [] + @row_split_number = [row for row in range(3, row_block, 3)] + @for row in row_split_number: + const float *dst_@{row} = dst + @{row} * dst_stride; + @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); + size_t dst_stride_t = dst_stride << 2; + @col_split_num = col_block >> 4; + asm volatile( + // inc in depth + "movq %[inc_flag], %rax\\n" + "kmovw (%[mask]), %k1\\n" + "and $0x1, %rax\\n" + "je 0f\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n" + @else: + "vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vmovups @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "1:\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "2:\\n" + : + @list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)", "[mask] \"r\"(mask)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", \"%k1\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) + ); + @for row in row_split_number: + const float *src_@{row} = src + @{row} * src_stride; + @asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %k1\\n" + @loop_count = 16 + "cmp $@{loop_count}, %[depth]\\n" + "jb 1f\\n" + ".align 16\\n" + "0:\\n" + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 -
(row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "sub $@{loop_count}, %[depth]\\n" + "cmp $@{loop_count}, %[depth]\\n" + "jge 0b\\n" + "cmp $0, %[depth]\\n" + "je 2f\\n" + ".align 16\\n" + "1:\\n" + @loop_count = 1 + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + 
@src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "dec %[depth]\\n" + "jg 1b\\n" + ".align 16\\n" + "2:\\n" + "and $0x2, %[inc_flag]\\n" + "je 3f\\n" + "and $0x3, %[act_flag]\\n" + "je 3f\\n" + // relu + "vxorps %zmm31, %zmm31, %zmm31\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + "vmaxps %%zmm@{row * col_split_num + col}, %%zmm31, %%zmm@{row * col_split_num + col} %{{%%k1}}\\n" + @else: + "vmaxps %%zmm@{row * col_split_num + col}, %%zmm31, %%zmm@{row * col_split_num + col}\\n" + "and $0x1, %[act_flag]\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm30\\n" + "vbroadcastss %%xmm@{30}, %%zmm30\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + "vminps %%zmm@{row * col_split_num + col}, %%zmm30, %%zmm@{row * col_split_num + col} %{{%%k1}}\\n" + @else: + "vminps %%zmm@{row * col_split_num + col}, %%zmm30, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "3:\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}]) %{{%%k1}}\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}) %{{%%k1}}\\n" + @else: + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n" + : + @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[depth] \"r\"(depth)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)", 
"[mask] \"r\"(mask)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM) + ); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in new file mode 100644 index 00000000..335ed7f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in @@ -0,0 +1,231 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl_c/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @import math + @row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20} + @src_addr_stride = 3 + @asm_flag_list = [] + @row_split_number = [row for row in range(3, row_block, 3)] + @for row in row_split_number: + const float *dst_@{row} = dst + @{row} * dst_stride; + @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); + size_t dst_stride_t = dst_stride << 2; + @col_split_num = col_block >> 4; + asm volatile( + // inc in depth + "movq %[inc_flag], %rax\\n" + "and $0x1, %rax\\n" + "je 0f\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n" + @else: + "vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vmovups @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "1:\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "2:\\n" + : + @list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) + ); + @for row in row_split_number: + const float *src_@{row} = src + @{row} * src_stride; + @asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")"); + size_t src_stride_t = src_stride 
<< 2; + asm volatile( + @loop_count = 16 + "cmp $@{loop_count}, %[depth]\\n" + "jb 1f\\n" + ".align 16\\n" + "0:\\n" + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "sub $@{loop_count}, %[depth]\\n" + "cmp $@{loop_count}, %[depth]\\n" + "jge 0b\\n" + "cmp $0, %[depth]\\n" + "je 2f\\n" + ".align 16\\n" + "1:\\n" + @loop_count = 1 + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, 
row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "dec %[depth]\\n" + "jg 1b\\n" + ".align 16\\n" + "2:\\n" + "and $0x2, %[inc_flag]\\n" + "je 3f\\n" + "and $0x3, %[act_flag]\\n" + "je 3f\\n" + // relu + "vxorps %zmm31, %zmm31, %zmm31\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vmaxps %%zmm@{row + col * row_block}, %%zmm31, %%zmm@{row + col * row_block}\\n" + "and $0x1, %[act_flag]\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm30\\n" + "vbroadcastss %xmm30, %zmm30\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n" + ".align 16\\n" + "3:\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n" + : + @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[depth] \"r\"(depth)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM) + ); +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in new file mode 100644 index 00000000..641b1857 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in @@ -0,0 +1,85 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t deep, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + __m256 dst@{j * row_block + i}; + if (inc_flag) { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{j * row_block + i} = _mm256_load_ps(dst + @{j} * dst_stride + @{i * 8}); + } else if (bias == NULL) { + @for i in range(0, row_block * col_block >> 3): + dst@{i} = _mm256_setzero_ps(); + } else { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{j * row_block + i} = _mm256_load_ps(bias + @{j * 8}); + } + for (int i = 0; i < (deep >> 3); ++i) { + @for i in range(0, 8): + // block@{i} + @if col_block == 32: + @for row in range(0, row_block): + __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i})); + @for col in range(0, col_block >> 3): + __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block}); + @for row in range(0, row_block): + dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i}); + @else: + @for col in range(0, col_block >> 3): + __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block}); + @for row in range(0, row_block): + __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i})); + @for col in range(0, col_block >> 3): + dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i}); + src = src + src_stride; + weight += @{8 * col_block * 4}; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_min_ps(dst@{i + j * row_block}, relu6); + // relu + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu); + } + @if col_block == 32: + @for j in range(0, col_block >> 3): + @for i in
range(0, row_block): + _mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i}); + @else: + @for j in range(0, col_block >> 3): + @for i in range(0, row_block): + _mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i}); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in new file mode 100644 index 00000000..70178cf5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t deep, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @if col_block == 32: + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\\n" + "je 0f\\n" + @for col in range(0, min((col_block >> 3), 3)): + @for row in range(0, row_block): + @if col == 0: + "vmovups @{row * 32}(%[dst]), %%ymm@{row + col * row_block}\\n" + @else: + "vmovups @{row * 32}(%[dst], %[dst_stride], @{col}), %%ymm@{row + col * row_block}\\n" + @if col_block == 32: + @for row in range(0, row_block): + "vmovups @{row * 32}(%[dst_4]), %%ymm@{row + (col + 1) * row_block}\\n" + "jmp 2f\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vmovaps @{col * 32}(%[bias]), %%ymm@{row + col * row_block}\\n" + "jmp 2f\\n" + "1:\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vxorps %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}\\n" + "2:\\n" + : + @list = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"] + @if col_block == 32: + @list.append("[dst_4] \"r\"(dst_4)") + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, row_block * col_block >> 3)]), file=OUT_STREAM) + ); + asm volatile( + "0:\\n" + @for i in range(0, 8): + // block @{i} + @if col_block == 32: + @for row in range(0, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - row}\\n" + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - row_block}\\n" + @for row in range(0, row_block): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - row_block}, %%ymm@{15 - row}\\n" + @elif
col_block == 24: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @for row in range(0, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n" + @for col in range(0, col_block >> 3): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n" + @elif col_block == 16: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @for row in range(0, row_block >> 1): + "vbroadcastss @{row * 64 + i}(%[src]), %%ymm@{14 - col}\\n" + "vbroadcastss @{row * 64 + 32 + i}(%[src]), %%ymm@{13 - col}\\n" + @for col in range(0, col_block >> 3): + @for j in range(0, 2): + "vfmadd231ps %%ymm@{row * 2 + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n" + @for row in range(row_block >> 1 << 1, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n" + @for col in range(0, col_block >> 3): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n" + @else: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @split_num = 3 + @for row in range(0, int(row_block / split_num)): + @for j in range(0, split_num): + "vbroadcastss @{row * 96 + j * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - j}\\n" + @for col in range(0, col_block >> 3): + @for j in range(0, split_num): + "vfmadd231ps %%ymm@{row * split_num + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n" + @for row in range(int(row_block / split_num) * split_num, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}\\n" + @for col in range(0, col_block >> 3): + @for row in range(int(row_block / split_num) * split_num, row_block): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}, %%ymm@{15 - col}\\n" + "dec %[deep]\\n" + "add $@{col_block * 4 * 8}, %[weight]\\n" + "add %[src_stride], %[src]\\n" + "jg 0b\\n" + + "movq %[inc_flag], %rax\\n" + "and $0x2, %eax\\n" + "je 3f\\n" + "movq %[act_flag], %rax\\n" + "and $0x3, %eax\\n" + "je 3f\\n" + // relu + "vxorps %ymm15, %ymm15, %ymm15\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vmaxps %%ymm@{row + col * row_block}, %%ymm15, %%ymm@{row + col * row_block}\\n" + "and $0x1, %eax\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm14\\n" + "vpermps %ymm14, %ymm15, %ymm14\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vminps %%ymm@{row + col * row_block}, %%ymm14, %%ymm@{row + col * row_block}\\n" + "3:\\n" + @for col in range(0, min((col_block >> 3), 3)): + @for row in range(0, row_block): + @if col == 0: + "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst])\\n" + @else: + "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst], %[dst_stride], @{col})\\n" + @if col_block == 32: + @for row in range(0, row_block): + "vmovups %%ymm@{row + (col + 1) * row_block}, @{row * 32}(%[dst_4])\\n" + : + @list = ["[src] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @if col_block == 32: + @list.append("[dst_4] \"r\"(dst_4)") + @print(" : " + ", ".join(list),
file=OUT_STREAM) + @print(" : \"%rax\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, 16)]), file=OUT_STREAM) + ); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fill_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fill_parameter.h new file mode 100644 index 00000000..433bd4bb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fill_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FILL_PARAMETER_H_ +#define NNACL_FILL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct FillParameter { + OpParameter op_parameter_; +} FillParameter; + +#endif // NNACL_FILL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/flatten_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/flatten_parameter.h new file mode 100644 index 00000000..74049d69 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/flatten_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FLATTEN_PARAMETER_H_ +#define NNACL_FLATTEN_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct FlattenParameter { + OpParameter op_parameter_; + int axis_; +} FlattenParameter; + +#endif // NNACL_FLATTEN_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/format_transpose_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/format_transpose_parameter.h new file mode 100644 index 00000000..0015c178 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/format_transpose_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
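Stepping back from the parameter headers for a moment: the avx512 asm kernels generated above all follow one contract. Bit 0 of inc_flag selects whether the accumulators resume from dst; otherwise they start from bias when one is supplied, or from zero. A depth-long src-times-weight accumulation follows, and when bit 1 of inc_flag is set, act_flag requests ReLU (either of bits 0-1) and additionally ReLU6 (bit 0) before the stores. (The fma intrinsic template tests its act_flag bits in a slightly different order.) A NumPy sketch of that contract, with tile shapes assumed from the addressing rather than taken from any shipped API:

import numpy as np

def gemm_tile_ref(dst, src, weight, bias, act_flag, inc_flag):
    # src: (row_block, depth), weight: (depth, col_block), dst: (row_block, col_block)
    if inc_flag & 0x1:
        acc = dst.astype(np.float32)                 # resume accumulation in dst
    elif bias is not None:
        acc = np.broadcast_to(bias, dst.shape).astype(np.float32)
    else:
        acc = np.zeros_like(dst, dtype=np.float32)
    acc = acc + src @ weight                         # the unrolled depth loop
    if (inc_flag & 0x2) and (act_flag & 0x3):
        acc = np.maximum(acc, 0.0)                   # relu
        if act_flag & 0x1:
            acc = np.minimum(acc, 6.0)               # relu6
    return acc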
+ */ +#ifndef NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ +#define NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/infer/common_infer.h" +static const int FormatTransposeInput = 2; +typedef struct FormatTransposeParameter { + // Primitive parameter + OpParameter op_parameter_; + FormatC src_format_; + FormatC dst_format_; +} FormatTransposeParameter; + +#endif // NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.c new file mode 100644 index 00000000..d03d2798 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.c @@ -0,0 +1,319 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/activation_fp16.h" +#include <float.h> +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/fp16/exp_fp16.h" +#include "nnacl_c/errorcode.h" + +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero = vdupq_n_f16(0); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t src_value = vld1q_f16(src + offset); + float16x8_t rst_value = vmaxq_f16(src_value, zero); + vst1q_f16(dst + offset, rst_value); + } +#endif + for (; offset < ele_num; offset++) { + dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset]; + } + return NNACL_OK; +} + +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t six_data = vdupq_n_f16(6); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t relu6_data = vld1q_f16(data + offset); + relu6_data = vmaxq_f16(relu6_data, zero_data); + relu6_data = vminq_f16(relu6_data, six_data); + vst1q_f16(dst + offset, relu6_data); + } +#endif + for (; offset < ele_num; offset++) { + dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset]; + dst[offset] = dst[offset] > 6.0f ? 6.0f : dst[offset]; + } + return NNACL_OK; +} + +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t alpha_data = vdupq_n_f16(alpha); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t src_tmp = vld1q_f16(src + offset); + float16x8_t mul_tmp = vmulq_f16(src_tmp, alpha_data); + uint16x8_t mask = vcleq_f16(src_tmp, zero_data); + vst1q_f16(dst + offset, vbslq_f16(mask, mul_tmp, src_tmp)); + } +#endif + for (; offset < ele_num; ++offset) { + dst[offset] = src[offset] > (float16_t)0.0f ?
src[offset] : (src[offset] * alpha); + } + return NNACL_OK; +} + +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + int count = (ele_num / C4NUM) * C4NUM; + for (; i < count; i += C4NUM) { + float32x4_t tmp; + simd_exp128(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp); + vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); + } +#endif + for (; i < ele_num; ++i) { + float temp; + simd_exp32(-src[i], &temp); + dst[i] = (float16_t)1.0f / ((float16_t)1.0f + temp); + } + return NNACL_OK; +} + +float16_t TanhOptFp16(float16_t src) { + if (src > 5.0f) { + return 1.0f; + } else if (src < -5.0f) { + return -1.0f; + } else { + float square = src * src; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * src; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + return a / b; + } +} + +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, + {17325.0f, 17325.0f, 17325.0f, 17325.0f}, + {135135.0f, 135135.0f, 135135.0f, 135135.0f}, + {28.0f, 28.0f, 28.0f, 28.0f}, + {3150.0f, 3150.0f, 3150.0f, 3150.0f}, + {62370.0f, 62370.0f, 62370.0f, 62370.0f}}; + float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; + float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; + int count = (ele_num / C4NUM) * C4NUM; + for (; i < count; i += C4NUM) { + float32x4_t input = vcvt_f32_f16(vld1_f16(src + i)); + float32x4_t square = vmulq_f32(input, input); + float32x4_t a = vmulq_f32( + vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]), + input); + float32x4_t b = vaddq_f32( + vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), + paramv[2]); + vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one))); + } +#endif + for (; i < ele_num; ++i) { + float input = src[i]; + float square = input * input; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + dst[i] = a / b; + dst[i] = MSMAX(dst[i], -1); + dst[i] = MSMIN(dst[i], 1); + } + return NNACL_OK; +} + +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 zero_data = vdupq_n_f16(0); + const MS_FLOAT16X8 three_data = vdupq_n_f16(3); + const MS_FLOAT16X8 six_data = vdupq_n_f16(6); + for (; i <= ele_num - C8NUM; i += C8NUM) { + MS_FLOAT16X8 in_data = MS_LDQ_F16(src + i); + MS_FLOAT16X8 tmp = MS_MAXQ_F16(in_data + three_data, zero_data); + tmp = MS_MINQ_F16(tmp, six_data); + MS_STQ_F16(dst + i, vmulq_f16(in_data, MS_DIVQ_F16(tmp, six_data))); + } +#endif + for (; i < ele_num; ++i) { + float16_t in = src[i]; + float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f); + dst[i] = in * relu6 / (float16_t)6.0f; + } + return NNACL_OK; +} + +int SwishFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + float32x4_t const_val = vdupq_n_f32(1.0f); + for (int num_max = ele_num - C16NUM; i <= num_max; i += C16NUM) { + float16x4x4_t ins = vld4_f16(src + i); + float32x4_t in0 = MS_CVT_F32_F16(ins.val[0]); + float32x4_t in1 = MS_CVT_F32_F16(ins.val[1]); + float32x4_t in2 = MS_CVT_F32_F16(ins.val[2]); + float32x4_t in3 = MS_CVT_F32_F16(ins.val[3]); + float32x4_t 
exp0 = simd_exp128_f32(vnegq_f32(in0)); + float32x4_t exp1 = simd_exp128_f32(vnegq_f32(in1)); + float32x4_t exp2 = simd_exp128_f32(vnegq_f32(in2)); + float32x4_t exp3 = simd_exp128_f32(vnegq_f32(in3)); + float32x4_t res0 = MS_DIVQ_F32(in0, vaddq_f32(const_val, exp0)); + float32x4_t res1 = MS_DIVQ_F32(in1, vaddq_f32(const_val, exp1)); + float32x4_t res2 = MS_DIVQ_F32(in2, vaddq_f32(const_val, exp2)); + float32x4_t res3 = MS_DIVQ_F32(in3, vaddq_f32(const_val, exp3)); + float16x4x4_t res = {MS_CVT_F16_F32(res0), MS_CVT_F16_F32(res1), MS_CVT_F16_F32(res2), MS_CVT_F16_F32(res3)}; + vst4_f16(dst + i, res); + } + for (int num_max = ele_num - C4NUM; i <= num_max; i += C4NUM) { + float32x4_t in = MS_CVT_F32_F16(vld1_f16(src + i)); + float16x4_t res = MS_CVT_F16_F32(MS_DIVQ_F32(in, vaddq_f32(const_val, simd_exp128_f32(vnegq_f32(in))))); + vst1_f16(dst + i, res); + } +#endif + for (; i < ele_num; ++i) { + float temp = simd_exp32_f32(-src[i]); + dst[i] = src[i] / (1.0f + temp); + } + return NNACL_OK; +} + +int HSigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 zero_data = vdupq_n_f16(0); + const MS_FLOAT16X8 three_data = vdupq_n_f16(3); + const MS_FLOAT16X8 six_data = vdupq_n_f16(6); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + MS_FLOAT16X8 relu6_data = MS_LDQ_F16(src + offset) + three_data; + relu6_data = MS_MAXQ_F16(relu6_data, zero_data); + relu6_data = MS_MINQ_F16(relu6_data, six_data); + MS_STQ_F16(dst + offset, MS_DIVQ_F16(relu6_data, six_data)); + } +#endif + + for (; offset < ele_num; offset++) { + float16_t tmp = (src[offset] + 3.0 < 0.0) ? 0.0 : src[offset] + 3.0; + dst[offset] = ((tmp < 6.0) ? tmp : 6.0) / 6.0; + } + + return NNACL_OK; +} + +int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val) { + if (max_val <= min_val) { + return NNACL_ERR; + } + int i = 0; + if (min_val == FLT_MIN) { + for (i = 0; i < length; ++i) { + dst[i] = src[i] > max_val ? max_val : src[i]; + } + } else if (max_val == FLT_MAX) { + for (i = 0; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : src[i]; + } + } else { + for (i = 0; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? 
max_val : src[i]); + } + } + return NNACL_OK; +} + +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) +#ifdef ENABLE_NEON + for (int num_max = length - C16NUM; i <= num_max; i += C16NUM) { + float16x4x4_t ins = vld4_f16(src + i); + float32x4_t in0 = MS_CVT_F32_F16(ins.val[0]); + float32x4_t in1 = MS_CVT_F32_F16(ins.val[1]); + float32x4_t in2 = MS_CVT_F32_F16(ins.val[2]); + float32x4_t in3 = MS_CVT_F32_F16(ins.val[3]); + float32x4_t res0 = 0.5f * in0 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in0 * in0) * in0)); + float32x4_t res1 = 0.5f * in1 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in1 * in1) * in1)); + float32x4_t res2 = 0.5f * in2 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in2 * in2) * in2)); + float32x4_t res3 = 0.5f * in3 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in3 * in3) * in3)); + float16x4x4_t res = { + MS_CVT_F16_F32(res0), + MS_CVT_F16_F32(res1), + MS_CVT_F16_F32(res2), + MS_CVT_F16_F32(res3), + }; + vst4_f16(dst + i, res); + } + for (int num_max = length - C4NUM; i <= num_max; i += C4NUM) { + float32x4_t in = MS_CVT_F32_F16(vld1_f16(src + i)); + float32x4_t res = 0.5f * in * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in * in) * in)); + vst1_f16(dst + i, MS_CVT_F16_F32(res)); + } +#endif + for (; i < length; i++) { + dst[i] = + 0.5f * src[i] * + (1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i])); + } + } else { +#ifdef ENABLE_NEON + int C8 = DOWN_ROUND(length, C8NUM); + for (; i < C8; i += C8NUM) { + float16x8_t in = vld1q_f16(src + i); + const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); + vst1q_f16(dst + i, res); + } +#endif + for (; i < length; i++) { + dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f)); + } + } + return NNACL_OK; +} + +int SoftplusFp16(const float16_t *src, int length, float16_t *dst) { + int i = 0; + for (; i < length; ++i) { + single_exp_fp16(src[i], dst + i); + dst[i] = log1p(dst[i]); + } + return NNACL_OK; +} + +int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t one = MS_MOVQ_F16(1.0f); + for (; i <= length - 8; i += 8) { + float16x8_t src_tmp = MS_LDQ_F16(src + i); + float16x8_t exp_tmp = VexpFp16(src_tmp); // exp(x) + exp_tmp = MS_SUBQ_F16(exp_tmp, one); // exp(x) - 1 + float16x8_t elu_tmp = MS_MULQ_N_F16(exp_tmp, alpha); + uint16x8_t mask = vcleq_f16(src_tmp, MS_MOVQ_F16(0.0f)); + MS_STQ_F16(dst + i, vbslq_f16(mask, elu_tmp, src_tmp)); + } +#endif + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.h new file mode 100644 index 00000000..2283cd0c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/activation_fp16.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ACTIVATION_FP16_H_ +#define NNACL_FP16_ACTIVATION_FP16_H_ + +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/activation_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num); +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num); +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha); +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num); +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num); +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num); +int HSigmoidFp16(const float16_t *src, float16_t *dst, int ele_num); +int SwishFp16(const float16_t *src, float16_t *dst, int ele_num); +int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val); +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate); +int SoftplusFp16(const float16_t *src, int length, float16_t *dst); +int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_ACTIVATION_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.c new file mode 100644 index 00000000..2b9ffdbe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.c @@ -0,0 +1,273 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/fp16/arg_min_max_fp16.h" + +int ArgCompareAscFp16(const void *a, const void *b) { + float16_t a_value = ((ArgElement *)a)->data_.f16_data_; + float16_t b_value = ((ArgElement *)b)->data_.f16_data_; + if (b_value > a_value) { + return -1; + } + if (b_value < a_value) { + return 1; + } + + return 0; +} + +int ArgCompareDescFp16(const void *a, const void *b) { + float16_t b_value = ((ArgElement *)b)->data_.f16_data_; + float16_t a_value = ((ArgElement *)a)->data_.f16_data_; + if (b_value > a_value) { + return 1; + } + if (b_value < a_value) { + return -1; + } + + return 0; +} + +void ArgMaxTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count) { + bool out_value = param->out_value_; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float16_t value = -FLT_MAX; + int index = 0; + for (int k = 0; k < axis_count; ++k) { + float16_t value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } + if (out_value) { + outputfp16[output_offset + j] = value; + } else { + outputint[output_offset + j] = index; + } + if (output_value != NULL) { + output_value[output_offset + j] = value; + } + } + } +} + +void ArgMinTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count) { + bool out_value = param->out_value_; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float16_t value = FLT_MAX; + int index = 0; + for (int k = 0; k < axis_count; ++k) { + float16_t value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + if (out_value) { + outputfp16[output_offset + j] = value; + } else { + outputint[output_offset + j] = index; + } + if (output_value != NULL) { + output_value[output_offset + j] = value; + } + } + } +} + +void ArgMinMaxDim0Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func); + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[j].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[j].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[j].data_.f16_data_; + } + } + } + return; +} + +void ArgMinMaxDim1Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + 
const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + int in_shape1 = in_shape[1]; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + size_t offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = k; + param->arg_elements_[k].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func); + for (int k = 0; k < param->topk_; ++k) { + size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[k].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[k].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[k].data_.f16_data_; + } + } + } + } + return; +} + +void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + size_t offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[l].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[l].data_.f16_data_; + } + } + } + } + } +} + +void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + size_t offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_,
in_shape3, sizeof(ArgElement), *compare_func); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[l].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[l].data_.f16_data_; + } + } + } + } + } +} + +void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param) { + if (param->topk_ == 1) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); + + if (param->get_max_) { + ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); + } else { + ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); + } + return; + } + + COMPARE_FUNCTION compare_function = NULL; + if (param->get_max_) { + compare_function = ArgCompareDescFp16; + } else { + compare_function = ArgCompareAscFp16; + } + + switch (param->axis_) { + case 0: + ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 1: + ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 2: + ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 3: + ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function); + break; + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.h new file mode 100644 index 00000000..b54971f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arg_min_max_fp16.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_ARG_MIN_MAX_FP16_H_ +#define NNACL_FP16_ARG_MIN_MAX_FP16_H_ + +#include <float.h> +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARG_MIN_MAX_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c new file mode 100644 index 00000000..01866fad --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.c @@ -0,0 +1,1314 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/arithmetic_fp16.h" +#include <math.h> +#include "nnacl_c/common_func.h" +#include "nnacl_c/nnacl_utils.h" + +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param) { + TileDimensionsFp16(in0, in1, tile_in0, tile_in1, param); + return ElementAddFp16(tile_in0, tile_in1, out, size); +} + +void TileOneDimensionFp16(const void *input, void *output, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple) { + const float16_t *inData = (const float16_t *)input; + float16_t *outData = (float16_t *)output; + + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float16_t)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionFp16(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionFp16(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] * input1[index]; + } + return NNACL_OK; +} + +int
ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] * input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] * input1[0]; + } + } + return NNACL_OK; +} + +int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] * input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] * input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] * input1[0]; + output[index] = res > 0 ? 
res : 0; + } + } + return NNACL_OK; +} + +int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] * input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] * input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] + input1[index]; + } + return NNACL_OK; +} + +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] + input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] + input1[0]; + } + } + 
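+ // All ElementOpt*Fp16 kernels in this file follow the same pattern: the scalar side (input0 when
+ // first_scalar is true, otherwise input1) is broadcast once, and the other operand is streamed
+ // through the same NEON body and scalar tail as the plain elementwise kernel.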
return NNACL_OK; +} + +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } + float16x4_t zeros1 = vdup_n_f16(0.0f); + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmax_f16(vout, zeros1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] + input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] + input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] + input1[0]; + output[index] = res > 0 ? 
res : 0; + } + } + return NNACL_OK; +} + +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } + float16x4_t zeros1 = vdup_n_f16(0.0); + float16x4_t bounds1 = vdup_n_f16(6.0); + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmin_f16(vmax_f16(vout, zeros1), bounds1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] + input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] + input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] - input1[index]; + } + return NNACL_OK; +} + +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] - input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + 
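+ // each NEON iteration writes back eight fp16 lanes; the plain loop below drains any tail elements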
vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] - input1[0]; + } + } + return NNACL_OK; +} + +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] - input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] - input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] - input1[0]; + output[index] = res > 0 ? 
res : 0; + } + } + return NNACL_OK; +} + +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] - input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] - input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] / input1[index]; + } + return NNACL_OK; +} + +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] / input1[index]; + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] / input1[0]; + } + } + return NNACL_OK; +} + +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; 
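+ // the NEON fast path below performs no per-lane zero-divisor check; only the scalar tail
+ // reports NNACL_ERRCODE_DIVISOR_ZERO, matching the other Div kernels in this file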
+#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + float16_t res = input0[index] / input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMAX(input0[0] / input1[index], 0); + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index] / input1[0], 0); + } + } + return NNACL_OK; +} + +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(input0[index] / input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(input0[0] / input1[index], 0), 6); + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = 
vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] / input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { + if (!first_scalar) { + for (int i = 0; i < element_size; ++i) { + output[i] = input0[i] - floorf(input0[i] / input1[0]) * input1[0]; + } + } else { + for (int i = 0; i < element_size; ++i) { + output[i] = input0[0] - floorf(input0[0] / input1[i]) * input1[i]; + } + } + return NNACL_OK; +} + +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[i]); + } + return NNACL_OK; +} +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { + if (!first_scalar) { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[0]); + } + } else { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[0] / input1[i]); + } + } + return NNACL_OK; +} + +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); + for (; index <= element_size - 8; index += C8NUM) { + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input0 + index)), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input1 + index)), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) & (bool)(input1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1_ = vld1q_f16(input1 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_opt), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[0]) & (bool)(input1[index])); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0_
= vld1q_f16(input0 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_opt), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) & (bool)(input1[0])); + } + } + return NNACL_OK; +} + +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); + for (; index <= element_size - 8; index += C8NUM) { + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input0 + index)), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input1 + index)), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) | (bool)(input1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1_ = vld1q_f16(input1 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_opt), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[0]) | (bool)(input1[index])); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0_ = vld1q_f16(input0 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_opt), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) | (bool)(input1[0])); + } + } + return NNACL_OK; +} + +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size) { + ElementSubFp16(input0, input1, output, element_size); + return ElementMulFp16(output, output, output, element_size); +} + +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, bool first_scalar) { + ElementOptSubFp16(input0, input1, output, element_size, first_scalar); + return ElementMulFp16(output, output, output, element_size); +} + +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 
8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index], input1[index]); + } + return NNACL_OK; +} + +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[0], input1[index]); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmaxq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index], input1[0]); + } + } + return NNACL_OK; +} + +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[index], input1[index]); + } + return NNACL_OK; +} + +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[0], input1[index]); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vminq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[index], input1[0]); + } + } + return NNACL_OK; +} + +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vmvnq_u16(vceqq_f16(vin0, vin1))); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] != input1[index]; + } + return NNACL_OK; +} + +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt =
+ +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vmvnq_u16(vceqq_f16(vin0_opt, vin1))); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] != input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vmvnq_u16(vceqq_f16(vin0, vin1_opt))); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] != input1[0]; + } + } + return NNACL_OK; +} + +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] == input1[index]; + } + return NNACL_OK; +} + +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] == input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] == input1[0]; + } + } + return NNACL_OK; +} + +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] < input1[index]; + } + return NNACL_OK; +} + +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] < input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1_opt)); +
vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] < input1[0]; + } + } + return NNACL_OK; +} + +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] <= input1[index]; + } + return NNACL_OK; +} + +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] <= input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] <= input1[0]; + } + } + return NNACL_OK; +} + +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] > input1[index]; + } + return NNACL_OK; +} + +int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] > input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] > input1[0]; + } + } + return NNACL_OK; +} + +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < 
element_size; index++) { + output[index] = input0[index] >= input1[index]; + } + return NNACL_OK; +} + +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] >= input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] >= input1[0]; + } + } + return NNACL_OK; +}
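+ +/* Calling convention for the ElementOpt*Fp16 kernels above: they cover the degenerate broadcast in + * which exactly one operand is a single scalar. first_scalar == true means input0 holds the scalar + * and input1 the full tensor; false means the reverse. A minimal caller sketch, with hypothetical + * names: + * bool first_scalar = (in0_element_num == 1); + * ElementOptMaximumFp16(in0_data, in1_data, out_data, out_element_num, first_scalar); + * General shape broadcasting is materialized beforehand, e.g. with TileDimensionsFp16 declared in + * the header below. */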
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h new file mode 100644 index 00000000..0bd3ece9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_fp16.h @@ -0,0 +1,124 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ARITHMETIC_FP16_H_ +#define NNACL_FP16_ARITHMETIC_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void TileOneDimensionFp16(const void *input, void *output, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple); +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, + ArithmeticParameter *param); + +int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, bool first_scalar); +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, 
uint8_t *output, int element_size, + bool first_scalar); +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param); + +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARITHMETIC_FP16_H_ diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.c new file mode 100644 index 00000000..eb693c62 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.c @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <math.h> +#include "nnacl_c/fp16/arithmetic_self_fp16.h" + +int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = fabsf(input[i]); + } + return NNACL_OK; +} + +int ElementCosFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = cosf(input[i]); + } + return NNACL_OK; +} + +int ElementLogFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] <= 0) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = logf(input[i]); + } + return NNACL_OK; +} + +int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input[i] * input[i]; + } + return NNACL_OK; +} + +int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + output[i] = sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = 1.f / sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementSinFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = sinf(input[i]); + } + return NNACL_OK; +} + +int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = (float)(!((bool)(input[i]))); + } + return NNACL_OK; +} + +int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = roundf(input[i]); + } + return NNACL_OK; +} + +int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return NNACL_OK; +} + +int ElementCeilFp16(const float16_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ceilf(input[i]); + } + return NNACL_OK; +} + +int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + if (input[i] == 0.0f) { + return NNACL_ERR; + 
} + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} + +int ElementErfFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = erff(input[i]); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.h new file mode 100644 index 00000000..995abffd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/arithmetic_self_fp16.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ARITHMETIC_SELF_FP16_H_ +#define NNACL_FP16_ARITHMETIC_SELF_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementCosFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementLogFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSinFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementCeilFp16(const float16_t *input, float16_t *output, int number); + +int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementErfFp16(const float16_t *input, float16_t *output, int element_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARITHMETIC_SELF_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c new file mode 100644 index 00000000..d34d98e4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.c @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/batchnorm_fp16.h" +#include <math.h> +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +void BatchNormFp16(const float16_t *input, const float16_t *mean, const float16_t *variance, + const BatchNormStruct *param, int task_id, int thread_num, float16_t *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int cur_offset = completed_units * param->channel_; + + for (int i = 0; i < cur_unit; i++) { + const float16_t *unit_input = input + cur_offset; + float16_t *unit_output = output + cur_offset; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= param->channel_ - C8NUM; c += C8NUM) { + MS_FLOAT16X8 input_8 = MS_LDQ_F16(unit_input + c); + MS_FLOAT16X8 mean_8 = MS_LDQ_F16(mean + c); + MS_FLOAT16X8 variance_8 = MS_LDQ_F16(variance + c); + MS_FLOAT16X8 variance_sqrt = MS_SQRTFX8_F16(MS_ADDQ_F16(variance_8, MS_MOVQ_F16(param->epsilon_))); + MS_FLOAT16X8 output_8 = MS_DIVQ_F16(MS_SUBQ_F16(input_8, mean_8), variance_sqrt); + MS_STQ_F16(unit_output + c, output_8); + } +#endif + for (; c < param->channel_; c++) { + float16_t variance_sqrt = sqrtf(variance[c] + param->epsilon_); + unit_output[c] = (unit_input[c] - mean[c]) / variance_sqrt; + } + cur_offset += param->channel_; + } +} + +void FusedBatchNormFp16(const float16_t *input, const float16_t *scale, const float16_t *offset, const float16_t *mean, + const float16_t *variance, const BatchNormStruct *param, int task_id, int thread_num, + float16_t *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int cur_offset = completed_units * param->channel_; + + for (int i = 0; i < cur_unit; i++) { + const float16_t *unit_input = input + cur_offset; + float16_t *unit_output = output + cur_offset; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= param->channel_ - C8NUM; c += C8NUM) { + MS_FLOAT16X8 input_8 = MS_LDQ_F16(unit_input + c); + MS_FLOAT16X8 scale_8 = MS_LDQ_F16(scale + c); + MS_FLOAT16X8 offset_8 = MS_LDQ_F16(offset + c); + MS_FLOAT16X8 mean_8 = MS_LDQ_F16(mean + c); + MS_FLOAT16X8 variance_8 = MS_LDQ_F16(variance + c); + MS_FLOAT16X8 variance_sqrt = MS_SQRTFX8_F16(MS_ADDQ_F16(variance_8, MS_MOVQ_F16(param->epsilon_))); + MS_FLOAT16X8 norm_val = MS_DIVQ_F16(MS_SUBQ_F16(input_8, mean_8), variance_sqrt); + MS_FLOAT16X8 output_8 = MS_ADDQ_F16(MS_MULQ_F16(norm_val, scale_8), offset_8); + MS_STQ_F16(unit_output + c, output_8); + } +#endif + for (; c < param->channel_; c++) { + float16_t variance_sqrt = sqrtf(variance[c] + param->epsilon_); + float16_t norm_val = (unit_input[c] - mean[c]) / variance_sqrt; + unit_output[c] = norm_val * scale[c] + offset[c]; + } + cur_offset += param->channel_; + } +} + +void FusedBatchNormFp16MeanVar(const float16_t *input, float16_t *run_mean, float16_t *run_var, + const BatchNormStruct *param, float16_t *save_mean, float16_t *save_var) { + const float N = (float)param->unit_; + const float VN = N; + const float VNUB = (N > 1.0f) ? (N - 1.0f) : 1.0f; + const float momentum = (1.0f - param->momentum_); + + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_mean[c] += input[idx]; + } + } + for (int c = 0; c < param->channel_; c++) { + run_mean[c] /= (float16_t)N; + } + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_var[c] += (float16_t)((float)(input[idx] - run_mean[c]) * (float)(input[idx] - run_mean[c])); + } + } + for (int c = 0; c < param->channel_; c++) { + float unbiased_var = ((float)run_var[c] / VNUB); + run_var[c] = (float16_t)((float)run_var[c] / VN); + save_mean[c] = (float16_t)(momentum * (float)save_mean[c] + (1.0f - momentum) * (float)run_mean[c]); + save_var[c] = (float16_t)(momentum * (float)save_var[c] + (1.0f - momentum) * unbiased_var); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h new file mode 100644 index 00000000..b74a083d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/batchnorm_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_BATCHNORM_FP16_H_ +#define NNACL_FP16_BATCHNORM_FP16_H_ + +#include "nnacl_c/kernel/batch_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormFp16(const float16_t *input, const float16_t *mean, const float16_t *variance, + const BatchNormStruct *param, int task_id, int thread_num, float16_t *output); +void FusedBatchNormFp16(const float16_t *input, const float16_t *scale, const float16_t *offset, const float16_t *mean, + const float16_t *variance, const BatchNormStruct *param, int task_id, int thread_num, + float16_t *output); +void FusedBatchNormFp16MeanVar(const float16_t *input, float16_t *run_mean, float16_t *run_var, + const BatchNormStruct *param, float16_t *save_mean, float16_t *save_var);
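+ +/* Reference semantics of the kernels declared above, per channel c (epsilon_ from BatchNormStruct): + * BatchNormFp16: out = (in - mean[c]) / sqrtf(variance[c] + epsilon_) + * FusedBatchNormFp16: out = scale[c] * (in - mean[c]) / sqrtf(variance[c] + epsilon_) + offset[c] + * The NEON paths apply the same arithmetic C8NUM channels at a time. */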
+#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_BATCHNORM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/cast_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/cast_fp16.h new file mode 100644 index 00000000..6c556a5d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/cast_fp16.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CAST_FP16_H_ +#define NNACL_FP16_CAST_FP16_H_ + +#include "nnacl_c/op_base.h" +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) +#include <arm_neon.h> + +#ifdef __cplusplus +extern "C" { +#endif + +inline void BoolToFloat16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Float16ToInt32(const float16_t *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + +inline void Float16ToInt64(const float16_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +#ifdef ENABLE_ARM64 +inline void Float32ToFloat16(const float *__restrict input, float16_t *__restrict output, int number) { + int count = (number & ~(C8NUM - 1)); + int i = 0; + for (; i < count; i += C8NUM) { + float32x4_t in1 = vld1q_f32(input + i); + float16x4_t out1 = vcvt_f16_f32(in1); + float32x4_t in2 = vld1q_f32(input + i + 4); + float16x4_t out2 = vcvt_f16_f32(in2); + float16x8_t out = vcombine_f16(out1, out2); + vst1q_f16(output + i, out); + } + for (; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Float16ToFloat32(const float16_t *__restrict input, float *__restrict output, int number) { + int count = number & ~(C8NUM - 1); + int i = 0; + for (; i < count; i += C8NUM) { + float16x8_t in = vld1q_f16(input + i); + float16x4_t in1 = vget_low_f16(in); + float16x4_t in2 = vget_high_f16(in); + float32x4_t out1 = vcvt_f32_f16(in1); + vst1q_f32(output + i, out1); + float32x4_t out2 = vcvt_f32_f16(in2); + vst1q_f32(output + i + C4NUM, out2); + } + for (; i < number; ++i) { + output[i] = (float)input[i]; + } +} +#else +void Float32ToFloat16(const float *input, float16_t *output, int number); + +void Float16ToFloat32(const float16_t *input, float *output, int number); +#endif + +#ifdef __cplusplus +} +#endif +#endif +#endif  // NNACL_FP16_CAST_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.c new file mode 100644 index 00000000..1cce761d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.c @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/fp16/common_func_fp16.h" + +void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t oc_stride, size_t hw_stride, + ActType act_type, int size) { + if (size == 0) { + return; + } + for (int oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size, oc_mod = oc % size; + for (int hw = 0; hw < plane_size; hw++) { + int src_index = oc_div * size * hw_stride + hw * size + oc_mod; + int dst_index = hw * oc_stride + oc; + float16_t value = src_ptr_[src_index]; + if (bias_ptr != NULL) { + value = value + bias_ptr[oc]; + } + value = (act_type == ActType_Relu || act_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value); + value = (act_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } + return; +} + +void PostConvFuncFp16C8(const float16_t *c8_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane, + size_t oc_stride, ActType act_type) { +#ifdef ENABLE_ARM64 + size_t oc8mod = oc % C8NUM; + size_t oc8div = oc - oc8mod; + size_t stride_size = oc_stride * sizeof(float16_t); + PostFuncBiasReluC8Fp16(nhwc_out, c8_out, bias, oc8div, oc8mod, plane, stride_size, act_type); +#else + PostConvFuncCommFp16(nhwc_out, c8_out, bias, oc, plane, oc_stride, plane, act_type, C8NUM); +#endif +} + +void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane, + size_t plane_stride, ActType act_type) { +#ifdef ENABLE_ARM64 + size_t oc4mod = oc % C4NUM; + size_t oc4div = oc - oc4mod; + size_t stride_size = (plane_stride - plane) * C4NUM * sizeof(float16_t); + PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type); +#else + PostConvFuncCommFp16(nhwc_out, c4_out, bias, oc, plane, oc, plane_stride, act_type, C4NUM); +#endif +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.h new file mode 100644 index 00000000..95be975c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/common_func_fp16.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_COMMON_FUNC_FP16_H_ +#define NNACL_FP16_COMMON_FUNC_FP16_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* deconv common */ +void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t stride, ActType act_type); +void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); + +/* deconv winograd */ +void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t output_channel, + size_t plane_size, size_t plane_stride, ActType act_type); +void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +#ifdef __cplusplus +} +#endif +#endif  // NNACL_FP16_COMMON_FUNC_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/constant_of_shape_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/constant_of_shape_fp16.h new file mode 100644 index 00000000..41719d6a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/constant_of_shape_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ +#define NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +#ifdef ENABLE_FP16 +inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} +#endif
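+ +/* ConstantOfShapeFp16 fills output[start, end) with a single value; the start/end pair lets a + * caller split one fill across threads. A sketch with hypothetical names, using the UP_DIV and + * MSMIN helpers from op_base.h: + * int stride = UP_DIV(total_num, thread_num); + * int begin = task_id * stride; + * ConstantOfShapeFp16(out_data, begin, MSMIN(begin + stride, total_num), value); + */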
+ */ + +#include "nnacl_c/fp16/conv_depthwise_fp16.h" +#include +#include "nnacl_c/fp16/activation_fp16.h" + +#ifdef ENABLE_ARM82_A32 +void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + *output_ptr++ += weight_ptr[c] * input_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +#ifdef ENABLE_ARM +static void ConvDw3x3RowLeftFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + v0 = MS_MOVQ_F16((float16_t)0.0); + int ic = 0; + for (; ic < channel - 7; ic += 8) { + v1 = MS_LDQ_F16(src + ic); + v2 = MS_LDQ_F16(src + channel + ic); + v3 = MS_LDQ_F16(src + 2 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d1 = src[i + ic]; + float16_t d2 = src[i + ic + channel]; + float16_t d3 = src[i + ic + 2 * channel]; + remain_line[i] = (float16_t)0.0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = d3 - d1; + } + } +} + +static void ConvDw3x3RowMiddleFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + int ic = 0; + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + v2 = MS_LDQ_F16(src + 2 * channel + ic); + v3 = MS_LDQ_F16(src + 3 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + float16_t d2 = src[i + ic + 2 * channel]; + float16_t d3 = src[i + ic + 3 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = d3 - d1; + } + } +} + +static void ConvDw3x3RowRightFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + int ic = 0; + v3 = MS_MOVQ_F16((float16_t)0.0); + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + v2 = MS_LDQ_F16(src + 2 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; 
+ memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + float16_t d2 = src[i + ic + 2 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = (float16_t)0.0 - d1; + } + } +} + +static void ConvDw3x3RowSingleFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2; + int ic = 0; + v2 = MS_MOVQ_F16((float16_t)0.0); + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_STQ_F16(line + lw * ic, v0); + MS_STQ_F16(line + lw * ic + 8, v1); + MS_STQ_F16(line + lw * ic + 16, b2); + memset(line + lw * ic + 24, 0, 16); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + remain_line[i] = d0; + remain_line[i + 8] = d1; + remain_line[i + 16] = (float16_t)0.0 - d1; + } + } +} + +static void ConvDw3x3InitTopFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + int c8 = UP_ROUND(channel, C8NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(line0, 0, c8 * lw * sizeof(float16_t)); + ConvDw3x3RowLeftFp16(src, line1, lw, channel); + ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3InitRowFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + int lw = UP_DIV(width, C2NUM) * C4NUM; + ConvDw3x3RowLeftFp16(src - width * channel, line0, lw, channel); + ConvDw3x3RowLeftFp16(src, line1, lw, channel); + ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow 
* 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3RowFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c8 = UP_ROUND(channel, C8NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(tmp, 0, c8 * lw * sizeof(float16_t)); + ConvDw3x3RowLeftFp16(src, tmp, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3BottomFp16(float16_t **lines, int width, int channel) { + float16_t *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c8 = UP_ROUND(channel, C8NUM); + memset(tmp, 0, UP_DIV(width, C2NUM) * c8 * C4NUM * sizeof(float16_t)); +} + +void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data, + int width, int ori_channel, bool relu, bool relu6) { + int channel = ori_channel; + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + for (; channel > 0; channel -= 8) { + MS_FLOAT16X8 bias = MS_LDQ_F16(bias_data); + bias_data += 8; + MS_FLOAT16X8 g00 = MS_LDQ_F16(weight); + MS_FLOAT16X8 g01 = MS_LDQ_F16(weight + 8); + MS_FLOAT16X8 g02 = MS_LDQ_F16(weight + 16); + MS_FLOAT16X8 g03 = MS_LDQ_F16(weight + 24); + MS_FLOAT16X8 g10 = MS_LDQ_F16(weight + 32); + MS_FLOAT16X8 g11 = MS_LDQ_F16(weight + 40); + MS_FLOAT16X8 g12 = MS_LDQ_F16(weight + 48); + MS_FLOAT16X8 g13 = MS_LDQ_F16(weight + 56); + MS_FLOAT16X8 g20 = MS_LDQ_F16(weight + 64); + MS_FLOAT16X8 g21 = MS_LDQ_F16(weight + 72); + MS_FLOAT16X8 g22 = MS_LDQ_F16(weight + 80); + MS_FLOAT16X8 g23 = MS_LDQ_F16(weight + 88); + weight += 96; + float16_t *cur_dst = dst; + int ow = 0; + for (; ow < width - 1; ow += 2) { + MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00); + MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01); + MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02); + MS_FLOAT16X8 acc3 = MS_MULQ_F16(MS_LDQ_F16(line0 + 24), g03); + line0 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12); + acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line1 + 24), g13); + + line1 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22); + acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line2 + 24), g23); + + line2 += 32; + MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1)); + MS_FLOAT16X8 res1 = MS_ADDQ_F16(acc1, MS_SUBQ_F16(acc3, acc2)); + res0 = MS_ADDQ_F16(res0, bias); + res1 = MS_ADDQ_F16(res1, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0)); + res1 = MS_MAXQ_F16(res1, MS_MOVQ_F16((float16_t)0.0)); + } + if (relu6) { + res0 = MS_MINQ_F16(res0, 
MS_MOVQ_F16((float16_t)6.0)); + res1 = MS_MINQ_F16(res1, MS_MOVQ_F16((float16_t)6.0)); + } + if (channel >= 8) { + MS_STQ_F16(cur_dst, res0); + MS_STQ_F16(cur_dst + ori_channel, res1); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + cur_dst[ori_channel + i] = res1[i]; + } + } + cur_dst += 2 * ori_channel; + } + if (ow < width) { + MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00); + MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01); + MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02); + line0 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12); + + line1 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22); + + line2 += 32; + MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1)); + res0 = MS_ADDQ_F16(res0, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0)); + } + if (relu6) { + res0 = MS_MINQ_F16(res0, MS_MOVQ_F16((float16_t)6.0)); + } + if (channel >= 8) { + MS_STQ_F16(cur_dst, res0); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + } + } + } + dst += 8; + } +} + +void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { + int units = UP_DIV(conv_param->output_w_, C2NUM); + int c8 = UP_ROUND(conv_param->input_channel_, C8NUM); + int line = conv_param->input_channel_ * conv_param->input_w_; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + float16_t *line0 = buffer; + float16_t *line1 = buffer + units * c8 * C4NUM; + float16_t *line2 = buffer + units * c8 * C4NUM * 2; + float16_t *lines[3] = {line0, line1, line2}; + int oh = start_oh; + if (oh == 0) { + // input trans + ConvDw3x3InitTopFp16(src, lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3InitRowFp16(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + for (oh = start_oh + 1; oh < end_oh - 1; oh++) { + // input trans + ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, + conv_param->input_channel_, relu, relu6); + } + if (oh == conv_param->output_h_ - 1) { + // input trans + ConvDw3x3BottomFp16(lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } +} + +#endif + +void ConvDwFp16(float16_t 
*output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->stride_w_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + for (int b = 0; b < conv_param->output_batch_; b++) { + const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float16_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(float16_t)); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float16_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + float16_t *dst_w = dst_data + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + ReluFp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + } + if (relu6) { + Relu6Fp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + } + } + } +} + +/*conv depthwise fp16 begin*/ +void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, + bool is_relu6) { + for (int c = 0; c < C8NUM; c++) { + dst[c] = 0; + } + const float16_t *src_kh = src; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst); + dst_8 = vfmaq_f16(dst_8, src_8, 
weight_8); + vst1q_f16(dst, dst_8); + + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop + for (int c = 0; c < C8NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float16_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float16_t *src_h = src + ih * sliding->in_h_step_; + + float16_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float16_t *src_w = src_h + iw * sliding->block_channel_; + + const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; +#ifdef ENABLE_ARM64 + ConvDwFp16Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t), + conv_param->kernel_w_ * C8NUM * sizeof(float16_t), relu, relu6); +#else + DepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM, relu, relu6); +#endif + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, + int in_sh_step, int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float16_t *src_kh = src_w; + const float16_t *weight_kh = weight; + for (int c = 0; c < C8NUM; c++) { + dst_w[c] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_ARM64 + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_w); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_w, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } +#endif + src_kw += 
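+          // step to the next input column inside the kernel window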
in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + // add bias and relu + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +// conv depthwise fp16: sliding window +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DepthwiseBorderFp16(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const float16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float16_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + ConvDwFp16Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t), + sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t), + sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t), + sliding->in_kw_step_ * sizeof(float16_t), relu, relu6); +#else + DepthwiseCenterFp16(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, + sliding->in_kh_step_, sliding->in_kw_step_, relu, relu6); +#endif + } + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nchwc8 +} +/*conv depthwise fp16 end*/ + +/*deconv depthwise fp16 begin*/ +void DeconvDepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, + int width, int in_kh_step, int in_kw_step, int
kernel_w_step) { + float16_t *dst_kh = dst; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); + + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop +} + +void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int top, int bottom, + int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + const float16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float16_t *dst_h = dst + oh * sliding->in_h_step_; + + const float16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float16_t *dst_w = dst_h + ow * sliding->block_channel_; + + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + float16_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; +#ifdef ENABLE_ARM64 + DeconvDwFp16Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t), + conv_param->kernel_w_ * C8NUM * sizeof(float16_t)); +#else + DeconvDepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM); +#endif + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float16_t *dst_kh = dst_w; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_NEON + float16x8_t src_8 = vld1q_f16(src_w); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + 
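+        // move the scatter window down one kernel row in the output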
dst_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDepthwisePostFuncFp16(float16_t *dst, const float16_t *bias, int block_channel, + const ConvParameter *conv_param) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + int hw = conv_param->output_h_ * conv_param->output_w_; + int hw8 = hw / C8NUM * C8NUM; + float16x8_t bias_value = vld1q_f16(bias); + float16x8_t zero = vdupq_n_f16(0.0f); + float16x8_t six = vdupq_n_f16(6.0f); + + int i = 0; + for (; i < hw8; i += C8NUM) { + float16_t *dst_ptr = dst + i * block_channel; + float16x8_t dst_value0 = vld1q_f16(dst_ptr); + float16x8_t dst_value1 = vld1q_f16(dst_ptr + C1NUM * block_channel); + float16x8_t dst_value2 = vld1q_f16(dst_ptr + C2NUM * block_channel); + float16x8_t dst_value3 = vld1q_f16(dst_ptr + C3NUM * block_channel); + float16x8_t dst_value4 = vld1q_f16(dst_ptr + C4NUM * block_channel); + float16x8_t dst_value5 = vld1q_f16(dst_ptr + C5NUM * block_channel); + float16x8_t dst_value6 = vld1q_f16(dst_ptr + C6NUM * block_channel); + float16x8_t dst_value7 = vld1q_f16(dst_ptr + C7NUM * block_channel); + + dst_value0 = vaddq_f16(dst_value0, bias_value); + dst_value1 = vaddq_f16(dst_value1, bias_value); + dst_value2 = vaddq_f16(dst_value2, bias_value); + dst_value3 = vaddq_f16(dst_value3, bias_value); + dst_value4 = vaddq_f16(dst_value4, bias_value); + dst_value5 = vaddq_f16(dst_value5, bias_value); + dst_value6 = vaddq_f16(dst_value6, bias_value); + dst_value7 = vaddq_f16(dst_value7, bias_value); + if (relu) { + dst_value0 = vmaxq_f16(dst_value0, zero); + dst_value1 = vmaxq_f16(dst_value1, zero); + dst_value2 = vmaxq_f16(dst_value2, zero); + dst_value3 = vmaxq_f16(dst_value3, zero); + dst_value4 = vmaxq_f16(dst_value4, zero); + dst_value5 = vmaxq_f16(dst_value5, zero); + dst_value6 = vmaxq_f16(dst_value6, zero); + dst_value7 = vmaxq_f16(dst_value7, zero); + } + if (relu6) { + dst_value0 = vminq_f16(dst_value0, six); + dst_value1 = vminq_f16(dst_value1, six); + dst_value2 = vminq_f16(dst_value2, six); + dst_value3 = vminq_f16(dst_value3, six); + dst_value4 = vminq_f16(dst_value4, six); + dst_value5 = vminq_f16(dst_value5, six); + dst_value6 = vminq_f16(dst_value6, six); + dst_value7 = vminq_f16(dst_value7, six); + } + vst1q_f16(dst_ptr, dst_value0); + vst1q_f16(dst_ptr + C1NUM * block_channel, dst_value1); + vst1q_f16(dst_ptr + C2NUM * block_channel, dst_value2); + vst1q_f16(dst_ptr + C3NUM * block_channel, dst_value3); + vst1q_f16(dst_ptr + C4NUM * block_channel, dst_value4); + vst1q_f16(dst_ptr + C5NUM * block_channel, dst_value5); + vst1q_f16(dst_ptr + C6NUM * block_channel, dst_value6); + vst1q_f16(dst_ptr + C7NUM * block_channel, dst_value7); + } + + float16_t *dst_ptr = dst + i * block_channel; + for (; i < hw; i++, dst_ptr += block_channel) { + float16x8_t dst_value0 = vld1q_f16(dst_ptr); + dst_value0 = vaddq_f16(dst_value0, bias_value); + dst_value0 = relu ? vmaxq_f16(dst_value0, zero) : dst_value0; + dst_value0 = relu6 ? 
vminq_f16(dst_value0, six) : dst_value0; + vst1q_f16(dst_ptr, dst_value0); + } +} + +// deconv depthwise fp16: sliding window +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + float16_t *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + DeconvDwFp16Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t), + sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t), + sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t), + sliding->in_kw_step_ * sizeof(float16_t)); +#else + DeconvDepthwiseCenterFp16(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, + sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDepthwisePostFuncFp16(dst_data, bias, sliding->block_channel_, conv_param); + } // output C8 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nchwc8 +} +/*deconv depthwise fp16 end*/ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.h new file mode 100644 index 00000000..36d273d1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_depthwise_fp16.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_CONV_DEPTHWISE_FP16_H_ +#define NNACL_FP16_CONV_DEPTHWISE_FP16_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels, + size_t input_channel, size_t input_step); +#ifdef ENABLE_ARM64 +void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, + size_t relu6); +void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, + size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, + size_t relu, size_t relu6); +void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w); +void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +#endif + +void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int task_id); + +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#ifdef ENABLE_ARM +void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data, + int width, int ori_channel, bool relu, bool relu6); +void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_CONV_DEPTHWISE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.c new file mode 100644 index 00000000..515186cd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.c @@ -0,0 +1,334 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/conv_fp16.h" +#include <string.h> +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/winograd_transform_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" + +void Im2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int kernel_plane = kernel_h * kernel_w; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_stride = (input_h * in_w + input_w) * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (dilation_h == 1 && dilation_w == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float16_t)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int n = kw_s; n < kw_e; n++) { + int input_x_stride = input_y_stride + n * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + n) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float16_t)); + } // kernel_w loop + } // kernel_h loop + } + } // tile num loop +} + +// fp16 convolution common (im2col+gemm) +void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * tile_n; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n); + if (start_hw >= end_hw) { + return; + } + int out_stride =
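+      // one GEMM step advances the output by tile_n pixels x output_channel values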
conv_param->output_channel_ * tile_n; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * tile_n; + col_major_input += task_id * deep * tile_n; + size_t input_size = deep * tile_n * sizeof(float16_t); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_; + for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, tile_n); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#else + RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#endif + MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, + real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc); + } + } +} + +void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(conv_param->op_parameter_.thread_num_); + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int input_block = UP_DIV(output_hw, tile_n); + int block_per_thread = UP_DIV(input_block, conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int end_block = MSMIN(start_block + block_per_thread, input_block); + if (start_block >= end_block) { + return; + } + int weight_block = UP_DIV(conv_param->output_channel_, C8NUM); + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += deep * tile_n * task_id; + col_major_input += deep * tile_n * task_id; + size_t input_size = deep * tile_n * sizeof(float16_t); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + for (int i = start_block; i < end_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : output_hw - i * tile_n; + memset(packed_input, 0, input_size); + Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_in_row, i * tile_n); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#else + RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#endif + const float16_t *cur_weight = packed_weight; + const float16_t *cur_bias = bias_data; + for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * deep, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? 
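+        // trailing weight block may hold fewer than C8NUM real output channels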
C8NUM : conv_param->output_channel_ - j * C8NUM; + int out_offset = j * output_hw * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(col_major_input, cur_weight, output_data + out_offset, cur_bias, conv_param->act_type_, deep, + real_in_row, real_weight_row, real_weight_row, OutType_Nhwc); + } + } + } +} + +void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(tile_n); + NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_); + int input_block = UP_DIV(param->row_, tile_n); + int weight_block = UP_DIV(param->col_, C8NUM); + + int block_per_thread = UP_DIV(input_block, param->op_parameter_.thread_num_); + int input_start_block = block_per_thread * task_id; + int input_end_block = MSMIN(input_start_block + block_per_thread, input_block); + if (input_start_block >= input_end_block) { + return; + } + input += input_start_block * tile_n * param->deep_; + pack_input += input_start_block * tile_n * param->deep_; + + int cur_row_cnt = MSMIN(block_per_thread * tile_n, param->row_ - input_start_block * tile_n); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_); +#else + RowMajor2Col12MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_); +#endif + for (int i = input_start_block; i < input_end_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n; + const float16_t *cur_weight = weight; + const float16_t *cur_bias = bias; + for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM; + int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row, + real_weight_row, real_weight_row, OutType_Nhwc); + } + pack_input += real_in_row * param->deep_; + } +} + +void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(tile_n); + NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_); + int input_block = UP_DIV(param->row_, tile_n); + int weight_block = UP_DIV(param->col_, C8NUM); + + int block_per_thread = UP_DIV(weight_block, param->op_parameter_.thread_num_); + int weight_start_block = block_per_thread * task_id; + int weight_end_block = MSMIN(weight_start_block + block_per_thread, weight_block); + if (weight_start_block >= weight_end_block) { + return; + } + for (int i = 0; i < input_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n; + const float16_t *cur_weight = weight + weight_start_block * C8NUM * param->deep_; + const float16_t *cur_bias = bias + weight_start_block * C8NUM; + for (int j = weight_start_block; j < weight_end_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? 
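+          // tail block: clamp to the output channels that actually remain in param->col_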
C8NUM : param->col_ - j * C8NUM; + int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row, + real_weight_row, real_weight_row, OutType_Nhwc); + } + pack_input += real_in_row * param->deep_; + } +} + +// fp16 convolution winograd +void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data, + float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, + const ConvParameter *conv_param, TransFp16FuncList trans_func) { +#ifdef ENABLE_ARM64 + const int tile_num = 16; +#else + const int tile_num = 12; +#endif + NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_); + NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + int per_thread_num = UP_DIV(output_count, conv_param->thread_num_); + int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num; + NNACL_CHECK_ZERO_RETURN(real_tile); + int output_tile_count = UP_DIV(output_count, real_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float16_t *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float16_t *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float16_t *tmp_data = buffer_list[2] + task_id * input_unit_square * C8NUM; + float16_t *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int out_tile_index = thread_id * real_tile; + int cal_num = output_count - thread_id * real_tile; + cal_num = cal_num > real_tile ? real_tile : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 16. + // For arm32, the tile_num is 12. The function(InputTransform4x4Pack12Fp16) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float16_t *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, C8NUM); + WinogradInputTransformOptStepFp16(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, + out_tile_index, out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float16_t *src_w = opt_trans_input + w_index * input_unit * tile_num * C8NUM; + for (int c = 0; c < UP_DIV(in_channel, C8NUM); c++) { + int real_c = in_channel - c * C8NUM; + real_c = real_c > C8NUM ? 
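+              // clamp the last channel group to the true input channel count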
C8NUM : real_c; + float16_t *src_c = src_w + c * input_unit_square * tile_num * C8NUM; + float16_t *dst_c = trans_input + c * tile_num * C8NUM; + trans_func.in_pack_func_(src_c, dst_c, C8NUM, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float16_t *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float16_t *gemm_weight = trans_weight + point_index * in_channel * oc8 * C8NUM; + MatMulFp16(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransformFp16(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float16_t *src_ptr = trans_input; + float16_t *dst_ptr = gemm_out; + float16_t *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#else + RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#endif + MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, + cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); + } else { + WinogradOutputNC8HW8TransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.h new file mode 100644 index 00000000..6d296b66 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/conv_fp16.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_CONV_FP16_H_ +#define NNACL_FP16_CONV_FP16_H_ + +#include <arm_neon.h> +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/fp16/winograd_utils_fp16.h" +#include "nnacl_c/fp16/winograd_transform_fp16.h" + +typedef float16_t *TmpBufferAddressFp16; +typedef float16_t *MatricesFp16; + +#ifdef __cplusplus +extern "C" { +#endif +void Im2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); + +// fp16 convolution common (im2col+gemm) +void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param); + +void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param); + +void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param); + +void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param); + +// fp16 convolution winograd +void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data, + float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, + const ConvParameter *conv_param, TransFp16FuncList trans_func); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_CONV_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.c new file mode 100644 index 00000000..31e1f08e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.c @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/crop_fp16.h" + +#include <string.h> + +#include "nnacl_c/crop_parameter.h" + +void Fp16Crop1D(const float16_t *input, float16_t *output, int *out_shape, int64_t *in_offset, int task_id, + int thread_count) { + const int out_batch = out_shape[0]; + int64_t task_id_stride = thread_count > 1 ?
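+  // split the output batch across threads in UP_DIV-sized strides; a single thread copies everything in one memcpy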
UP_DIV(out_batch, thread_count) : out_batch; + if (task_id_stride <= 0) { + return; + } + int n = task_id * task_id_stride; + if (n >= out_batch) { + return; + } + const float16_t *in_ptr = input + n + in_offset[0]; + float16_t *out_ptr = output + n; + int64_t out_dist_stride = MSMIN(out_batch - task_id * task_id_stride, task_id_stride); + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_dist_stride); +} + +void Fp16Crop2D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + for (int n = 0; n < out_batch; n++) { + int h = task_id * task_id_stride; + if (h >= out_height) { + return; + } + const float16_t *in_ptr = input + (n + in_offset[0]) * in_height + h + in_offset[1]; + float16_t *out_ptr = output + n * out_height + h; + int64_t out_dist_stride = MSMIN(out_height - task_id * task_id_stride, task_id_stride); + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_dist_stride); + } +} + +void Fp16Crop3D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int in_width = in_shape[2]; + + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + const int out_width = out_shape[2]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_h = in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_h = out_width; + const int out_stride_n = out_stride_h * out_height; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + const float16_t *in_ptr = + input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + in_offset[2]; + float16_t *out_ptr = output + n * out_stride_n + h * out_stride_h; + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_width); + } + } +} + +void Fp16Crop4D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int in_width = in_shape[2]; + const int in_channel = in_shape[3]; + + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + const int out_width = out_shape[2]; + const int out_channel = out_shape[3]; + + int64_t task_id_stride = thread_count > 1 ? 
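+  // output rows are partitioned across threads, mirroring the 1D/2D/3D variants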
UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_w = in_channel; + const int in_stride_h = in_channel * in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_w = out_channel; + const int out_stride_h = out_channel * out_width; + const int out_stride_n = out_stride_h * out_height; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + for (int w = 0; w < out_width; w++) { + const float16_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + + (w + in_offset[2]) * in_stride_w + in_offset[3]; + float16_t *out_ptr = output + n * out_stride_n + h * out_stride_h + w * out_stride_w; + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_channel); + } + } + } +} + +void Fp16Crop(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_num) { + switch (input_dim) { + case 1: + Fp16Crop1D(input, output, out_shape, in_offset, task_id, thread_num); + break; + case 2: + Fp16Crop2D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + case 3: + Fp16Crop3D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + case 4: + Fp16Crop4D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + default: + break; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.h new file mode 100644 index 00000000..d039dd4e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/crop_fp16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_CROP_FP16_H_ +#define NNACL_FP16_CROP_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +void Fp16Crop(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_num); + +#endif // NNACL_FP16_CROP_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c new file mode 100644 index 00000000..5d20218e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.c @@ -0,0 +1,70 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/custom_gru_fp16.h" +#include "nnacl_c/fp16/activation_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" + +void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input, + const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden, + const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param) { + int num_step = gru_param->num_step; + int batch_size = gru_param->batch_size; + int input_size = gru_param->input_size; + int hidden_size = gru_param->hidden_size; + int output_size = batch_size * hidden_size; + int double_output_size = output_size * C2NUM; + int col_align = UP_ROUND(hidden_size, C8NUM); + int weight_in_offset = col_align * input_size; + int weight_hidden_offset = col_align * hidden_size; + float16_t *input_gate = buffer[1]; + float16_t *hidden_gate = buffer[C3NUM]; + for (int i = 0; i < num_step; ++i) { + if (batch_size != 1) { + RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size, false); + for (int j = 0; j < C3NUM; ++j) { + MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size, false); + for (int j = 0; j < C3NUM; ++j) { + MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + } else { + for (int j = 0; j < C3NUM; ++j) { + VecMatmulFp16(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, hidden_size); + VecMatmulFp16(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size); + } + } + ElementAddFp16(input_gate, hidden_gate, input_gate, double_output_size); + SigmoidFp16(input_gate, input_gate, double_output_size); + ElementMulFp16(input_gate, hidden_gate + double_output_size, input_gate, output_size); + ElementAddFp16(input_gate, input_gate + double_output_size, input_gate, output_size); + TanhFp16(input_gate, input_gate, output_size); + ElementSubFp16(init_h, input_gate, hidden_gate, output_size); + ElementMulFp16(input_gate + output_size, hidden_gate, hidden_gate, output_size); + ElementAddFp16(input_gate, hidden_gate, output, output_size); + init_h = output; + output += output_size; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h new file mode 100644 index 00000000..c6f95a8d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/custom_gru_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CUSTOM_GRU_FP16_H_ +#define NNACL_FP16_CUSTOM_GRU_FP16_H_ +#ifdef ENABLE_ARM64 +#include "nnacl_c/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input, + const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden, + const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // NNACL_FP16_CUSTOM_GRU_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.c new file mode 100644 index 00000000..6525441a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.c @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16/deconv_fp16.h" +#include <float.h> + +void DeConvPostAddC8WithStride(const float16_t *source, float16_t *dest, size_t srcStride, size_t dststride, + size_t count) { + if (count == 0) { + return; + } + + const float16_t *src_ptr = source; + float16_t *dst_ptr = dest; + float16x8_t src1 = vld1q_f16(src_ptr); + float16x8_t dst1 = vld1q_f16(dst_ptr); + float16x8_t src2; + float16x8_t dst2; + size_t i = 1; + while (i < count - 1) { + dst1 = vaddq_f16(dst1, src1); + vst1q_f16(dst_ptr, dst1); + + src2 = vld1q_f16(src_ptr + srcStride); + dst2 = vld1q_f16(dst_ptr + dststride); + dst2 = vaddq_f16(dst2, src2); + vst1q_f16(dst_ptr + dststride, dst2); + i = i + 2; + src1 = vld1q_f16(src_ptr + srcStride + srcStride); + dst1 = vld1q_f16(dst_ptr + dststride + dststride); + + src_ptr = src_ptr + srcStride + srcStride; + dst_ptr = dst_ptr + dststride + dststride; + } + dst1 = vaddq_f16(dst1, src1); + vst1q_f16(dst_ptr, dst1); + if (i < count) { + src2 = vld1q_f16(src_ptr + srcStride); + dst2 = vld1q_f16(dst_ptr + dststride); + dst2 = vaddq_f16(dst2, src2); + vst1q_f16(dst_ptr + dststride, dst2); + } +} + +int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel, + const ConvParameter *conv_param) { + float16x8_t min_v = vdupq_n_f16(-FLT_MAX); + float16x8_t max_v = vdupq_n_f16(FLT_MAX); + if (conv_param->act_type_ == ActType_Relu) { + min_v = vdupq_n_f16(0.f); + } + if (conv_param->act_type_ == ActType_Relu6) { + min_v = vdupq_n_f16(0.f); + max_v = vdupq_n_f16(6.f); + } + + /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_ROUND(output_channel, C8NUM); + int in_plane16 = UP_ROUND(input_plane, 16); + int src_iw_stride = C8NUM; + int src_ih_stride = conv_param->input_w_ * C8NUM; + int src_kw_stride = in_plane16 * C8NUM; + int src_kh_stride = in_plane16 * conv_param->kernel_w_ * C8NUM; + int dst_oh_stride = conv_param->output_w_ * C8NUM; + int dst_ow_stride = C8NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM; + int dst_kw_stride = conv_param->dilation_w_ * C8NUM; + + NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_w_); + + for (int c = 0; c < oc8; c += 8) { + float16_t *dst_ptr = tmp + c * output_plane; + const float16_t *src_ptr = src + c * in_plane16 * kernel_plane; + memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float16_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + + const float16_t *src_in_ptr = src_ptr + ih * src_ih_stride + iw * src_iw_stride; + float16_t *dst_in_ptr = dst_ptr + oh * dst_oh_stride + ow * dst_ow_stride; + + for (int kh = kh_start; kh < kh_end; kh++) { + const float16_t *src_kh_ptr = src_in_ptr + kh * src_kh_stride; + float16_t *dst_kh_ptr = dst_in_ptr + kh *
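+          // dst_kh_stride = dilation_h_ * output_w_ * C8NUM: one dilated kernel row in the output plane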
dst_kh_stride; + DeConvPostAddC8WithStride(src_kh_ptr + kw_start * src_kw_stride, dst_kh_ptr + kw_start * dst_kw_stride, + src_kw_stride, dst_kw_stride, kw_end - kw_start); + } // kh + } // iw + } // ih + + /* add bias for current oh*ow*C8 + * write to output data ptr in nhwc format */ + float16x8_t bias_v = vld1q_f16(bias + c); + float16_t *pack_tmp_data = dst_ptr; + for (size_t i = 0; i < output_plane; i++) { + float16x8_t data_v = vld1q_f16(pack_tmp_data); + data_v = vaddq_f16(data_v, bias_v); + data_v = vminq_f16(data_v, max_v); + data_v = vmaxq_f16(data_v, min_v); + vst1q_f16(pack_tmp_data, data_v); + pack_tmp_data += C8NUM; + } + } // oc8 + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.h new file mode 100644 index 00000000..5d3a6572 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_DECONV_FP16_H_ +#define NNACL_FP16_DECONV_FP16_H_ + +#include <arm_neon.h> +#include <string.h> +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp16/common_func_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel, + const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_DECONV_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.c new file mode 100644 index 00000000..00b9450d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.c @@ -0,0 +1,519 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/fp16/deconv_winograd_fp16.h" +#include "nnacl_c/base/minimal_filtering_generator.h" + +void DeConvWgInputPackFp16(const float16_t *src_ptr, float16_t *dst_ptr, int channel, int stride) { + int ic4div = channel / C4NUM; + int ic4mod = channel % C4NUM; + const float16_t *src = src_ptr; + float16_t *dst = dst_ptr; + + for (int ic = 0; ic < ic4div; ic++) { + vst1_f16(dst, vld1_f16(src)); + dst += stride; + src += C4NUM; + } + + if (ic4mod != 0) { + int ic_res = 0; + for (; ic_res < ic4mod; ic_res++) { + dst[ic_res] = src[ic_res]; + } + for (; ic_res < C4NUM; ic_res++) { + dst[ic_res] = 0; + } + } + return; +} + +#ifdef ENABLE_ARM82_A32 +void DeconvWgMergeFp16A32Fun(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r0, %[src_ptr]\n" + "mov r1, %[dst_ptr]\n" + "mov r2, r1\n" + + "vld1.16 {d0}, [r0], %[src_step]\n" + "vld1.16 {d2}, [r1], %[dst_step]\n" + "vld1.16 {d4}, [r0], %[src_step]\n" + "vld1.16 {d6}, [r1], %[dst_step]\n" + "vadd.f16 d0, d0, d2\n" + "vld1.16 {d8}, [r0], %[src_step]\n" + "vadd.f16 d4, d4, d6\n" + "vst1.16 {d0}, [r2], %[dst_step]\n" + "vst1.16 {d4}, [r2], %[dst_step]\n" + + "vld1.16 {d10}, [r1], %[dst_step]\n" + "vld1.16 {d12}, [r0], %[src_step]\n" + "vadd.f16 d8, d8, d10\n" + "vld1.16 {d14}, [r1], %[dst_step]\n" + "vadd.f16 d12, d12, d14\n" + "vld1.16 {d0}, [r0], %[src_step]\n" + "vst1.16 {d8}, [r2], %[dst_step]\n" + "vst1.16 {d12}, [r2], %[dst_step]\n" + + "vld1.16 {d2}, [r1], %[dst_step]\n" + "vld1.16 {d4}, [r0], %[src_step]\n" + "vld1.16 {d6}, [r1], %[dst_step]\n" + "vadd.f16 d0, d0, d2\n" + "vadd.f16 d4, d4, d6\n" + "vst1.16 {d0}, [r2], %[dst_step]\n" + "vst1.16 {d4}, [r2], %[dst_step]\n" + + "vld1.16 {d8}, [r0], %[src_step]\n" + "vld1.16 {d10}, [r1], %[dst_step]\n" + "vld1.16 {d12}, [r0], %[src_step]\n" + "vld1.16 {d14}, [r1], %[dst_step]\n" + "vadd.f16 d8, d8, d10\n" + "vadd.f16 d12, d12, d14\n" + "vst1.16 {d8}, [r2], %[dst_step]\n" + "vst1.16 {d12}, [r2], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r0", "r1", "r2", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} +#endif + +void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float16_t *src_ptr = src; + float16_t *dst_ptr = dst; + size_t cuont8 = count / C8NUM * C8NUM; + int i = 0; + for (; i < cuont8; i += C8NUM) { +#ifdef ENABLE_ARM64 + size_t src_step = src_stride * sizeof(float16_t); + size_t dst_step = dst_stride * sizeof(float16_t); + asm volatile( + "mov x7, %[src_ptr]\n" + "mov x8, %[dst_ptr]\n" + "mov x10, x8\n" + + "ld1 {v0.4h}, [x7], %[src_step]\n" + "ld1 {v1.4h}, [x8], %[dst_step]\n" + "ld1 {v2.4h}, [x7], %[src_step]\n" + "ld1 {v3.4h}, [x8], %[dst_step]\n" + "fadd v0.4h, v0.4h, v1.4h\n" + "ld1 {v4.4h}, [x7], %[src_step]\n" + "fadd v2.4h, v2.4h, v3.4h\n" + "st1 {v0.4h}, [x10], %[dst_step]\n" + "st1 {v2.4h}, [x10], %[dst_step]\n" + + "ld1 {v5.4h}, [x8], %[dst_step]\n" + "ld1 {v6.4h}, [x7], %[src_step]\n" + "fadd v4.4h, v4.4h, v5.4h\n" + "ld1 {v7.4h}, [x8], %[dst_step]\n" + "fadd v6.4h, v6.4h, v7.4h\n" + "ld1 {v0.4h}, [x7], %[src_step]\n" + "st1 {v4.4h}, [x10], %[dst_step]\n" + "st1 {v6.4h}, [x10], %[dst_step]\n" + + "ld1 {v1.4h}, [x8], %[dst_step]\n" + "ld1 {v2.4h}, [x7], %[src_step]\n" + "ld1 {v3.4h}, [x8], %[dst_step]\n" + "fadd v0.4h, v0.4h, v1.4h\n" + "fadd v2.4h, v2.4h, v3.4h\n" + "st1 {v0.4h}, [x10], %[dst_step]\n" + "st1 {v2.4h}, [x10], %[dst_step]\n" + + "ld1 
{v4.4h}, [x7], %[src_step]\n" + "ld1 {v5.4h}, [x8], %[dst_step]\n" + "ld1 {v6.4h}, [x7], %[src_step]\n" + "ld1 {v7.4h}, [x8], %[dst_step]\n" + "fadd v4.4h, v4.4h, v5.4h\n" + "fadd v6.4h, v6.4h, v7.4h\n" + "st1 {v4.4h}, [x10], %[dst_step]\n" + "st1 {v6.4h}, [x10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#elif defined(ENABLE_ARM82_A32) + size_t src_step = src_stride * sizeof(float16_t); + size_t dst_step = dst_stride * sizeof(float16_t); + DeconvWgMergeFp16A32Fun(src_ptr, dst_ptr, src_step, dst_step); +#else + for (int j = 0; j < C8NUM; j++) { + const float16_t *s = src_ptr + j * src_stride; + float16_t *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + + for (; i < count; i++) { + float16x4_t src_data = vld1_f16(src_ptr); + float16x4_t dst_data = vld1_f16(dst_ptr); + dst_data = vadd_f16(src_data, dst_data); + vst1_f16(dst_ptr, dst_data); + + src_ptr += src_stride; + dst_ptr += dst_stride; + } + return; +} + +void DeConvWgCalWgFp16(const float16_t *tile_in, float16_t *tile_out, const float16_t *weight_buf, float16_t *tmp_buf, + const float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transferred, + const float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, + const ConvParameter *conv_param, const DeConvParam *deconv_param) { + int winograd_plane = unit_size * unit_size; + if (!transferred[unit_size]) { + WinogradTransLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, + DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + transferred[unit_size] = true; + } + + for (int index = 0; index < winograd_plane; index++) { + float16_t *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *dst = tmp_buf + index * deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + const float16_t *weight = weight_buf + index * deconv_param->ic_up_ * deconv_param->oc_up_; + TiledC4MatmulFp16(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div_, + deconv_param->oc_div_); + } + + WinogradTransLeftFp16(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRightFp16(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + + // Add to dest + for (int uhi = 0; uhi < unit_size; uhi++) { + int h_index = uhi * conv_param->stride_h_ + h_start; + for (int uwi = 0; uwi < unit_size; uwi++) { + int w_index = uwi * conv_param->stride_w_ + w_start; + + float16_t *dst = tile_out + w_index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_ + + h_index * deconv_param->out_tile_w_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + float16_t *src = tmp_buf + (uwi + uhi * unit_size) * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgMergeFp16(src, dst, C4NUM, C4NUM, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + return; +} + +void DeConvWgCalCommFp16(const float16_t *tile_in, float16_t *tile_out, const float16_t *weight, float16_t *tmp_buf, + int 
h_start, int w_start, int h_size, int w_size, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { + int count = deconv_param->oc_div_ * w_size * h_size; + int in_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + int out_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + const float16_t *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride; + TiledC4MatmulFp16(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div_, count); + + for (int uhi = 0; uhi < h_size; uhi++) { + for (int uwi = 0; uwi < w_size; uwi++) { + int w_index = (wi + uwi) * conv_param->stride_w_ + w_start; + int h_index = (hi + uhi) * conv_param->stride_h_ + h_start; + float16_t *dst = tile_out + h_index * out_stride * deconv_param->out_tile_w_ + w_index * out_stride; + float16_t *src = tmp_buf + (uwi + uhi * w_size) * out_stride; + DeConvWgMergeFp16(src, dst, C4NUM, C4NUM, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + } + } + return; +} + +int PackDeConvWgDataFp16(const float16_t *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { + int tmp_kernel_plane = unit->w_size_ * unit->h_size_; + int output_channel = conv_param->output_channel_; + int size = conv_param->input_channel_ * output_channel * tmp_kernel_plane; + float16_t *current_unit_weight = (float16_t *)malloc(size * sizeof(float16_t)); + if (current_unit_weight == NULL) { + return NNACL_NULL_PTR; + } + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + const float16_t *src_ic = nhwc_weight + deconv_param->kernel_plane_ * output_channel * ic; + float16_t *dst_ic = current_unit_weight + tmp_kernel_plane * output_channel * ic; + for (int uhi = 0; uhi < unit->h_size_; uhi++) { + for (int uwi = 0; uwi < unit->w_size_; uwi++) { + int src_h_offset = unit->h_start_ + uhi * conv_param->stride_h_; + int src_w_offset = unit->w_start_ + uwi * conv_param->stride_w_; + const float16_t *src_hw = src_ic + (src_h_offset * conv_param->kernel_w_ + src_w_offset) * output_channel; + float16_t *dst_hw = dst_ic + (uhi * unit->w_size_ + uwi) * output_channel; + memcpy(dst_hw, src_hw, output_channel * sizeof(float16_t)); + } + } + } + + if (unit->use_winograd_) { + /* Generate winograd */ + float matrix_g[64]; + float matrix_gt[64]; + float matrix_a[64]; + float matrix_at[64]; + float matrix_b[64]; + float matrix_bt[64]; + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, + DECONV_WINOGRAD_DEFAULT_UNIT, unit->h_size_); + if (ret != NNACL_OK) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR; + } + + /* winograd AT */ + unit->winograd_.AT_ = malloc(unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float16_t)); + if (unit->winograd_.AT_ == NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_NULL_PTR; + } + Float32ToFloat16(matrix_at, unit->winograd_.AT_, unit->winograd_.i_ * unit->winograd_.o_); + + /* winograd BT */ + unit->winograd_.BT_ = malloc(unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float16_t)); + if (unit->winograd_.BT_ == NULL) { + free(current_unit_weight); + free(unit->winograd_.AT_); + current_unit_weight = NULL; + unit->winograd_.AT_ = NULL; + return NNACL_NULL_PTR; + } + Float32ToFloat16(matrix_bt, unit->winograd_.BT_, unit->winograd_.o_ * 
unit->winograd_.o_); + + /* winograd Weight */ + size = conv_param->input_channel_ * output_channel * unit->winograd_.kh_ * unit->winograd_.kw_; + float16_t *winograd_unit_weight = (float16_t *)malloc(size * sizeof(float16_t)); + if (winograd_unit_weight == NULL) { + free(current_unit_weight); + free(unit->winograd_.AT_); + free(unit->winograd_.BT_); + current_unit_weight = NULL; + unit->winograd_.AT_ = NULL; + unit->winograd_.BT_ = NULL; + return NNACL_NULL_PTR; + } + + WinogradWeightTransformFp16(current_unit_weight, winograd_unit_weight, matrix_g, matrix_gt, C4NUM, + unit->winograd_.kh_, unit->h_size_, output_channel, conv_param->input_channel_, false); + + /* reset weight data & info */ + tmp_kernel_plane = unit->winograd_.kh_ * unit->winograd_.kw_; + free(current_unit_weight); + current_unit_weight = winograd_unit_weight; + winograd_unit_weight = NULL; + } + + /* trans mhwc -> hw1:k1-knc0-c4:k1-knc5-c8:hw2:k1-knc0-c4:k1 */ + float16_t *dst_weight = (float16_t *)unit->weight_; + size = deconv_param->ic_up_ * deconv_param->oc_up_ * tmp_kernel_plane; + memset(dst_weight, 0, size * sizeof(float16_t)); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; + for (int upi = 0; upi < tmp_kernel_plane; upi++) { + int src_index = ic * output_channel * tmp_kernel_plane + upi * output_channel + oc; + int dst_index = upi * deconv_param->oc_up_ * deconv_param->ic_up_ + oc4div * C4NUM * deconv_param->ic_up_ + + ic * C4NUM + oc4mod; + dst_weight[dst_index] = current_unit_weight[src_index]; + } + } + } + + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_OK; +} + +void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index, + int calculate_count, const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) { + NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_); + /* pack tile input */ + int tile_in_unit_stride = deconv_param->ic_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + float16x4_t zero = vdup_n_f16(0.0f); + + for (int unit_index = 0; unit_index < calculate_count; unit_index++) { + int plane_index = start_index + unit_index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + + float16_t *dst_unit = tile_in + unit_index * C4NUM; + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + float16_t *dst = dst_unit + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * tile_in_unit_stride; + int w_index = w_start + wi; + int h_index = h_start + hi; + if (w_index >= conv_param->input_w_ || h_index >= conv_param->input_h_) { + for (int ic4_index = 0; ic4_index < deconv_param->ic_div_; ic4_index++) { + vst1_f16(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, zero); + } + continue; + } + + const float16_t *src = nhwc_input_ + (w_index + h_index * conv_param->input_w_) * conv_param->input_channel_; + DeConvWgInputPackFp16(src, dst, conv_param->input_channel_, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM); + } + } + } + + /* compute */ + bool transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; + for (int i = 0; i < deconv_param->compute_size_; i++) { + DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; + if (unit->use_winograd_) { + float16_t *tmp_buf = 
(float16_t *)unit->tmp_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + /* winograd a buffer */ + if (unit->winograd_.kh_ >= DECONV_WINOGRAD_BUFFER_COUNT) { + return; + } + DeConvWgABuffer *tmp_a = &deconv_param->a_buffer_[unit->winograd_.kh_]; + float16_t *mid_a = (float16_t *)tmp_a->middle_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *dst_a = (float16_t *)tmp_a->dest_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *tmp_b = (float16_t *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgCalWgFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, + transferred, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, + conv_param, deconv_param); + } else { + float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div_ * unit->w_size_ * + unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; + DeConvWgCalCommFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->h_start_, unit->w_start_, + unit->h_size_, unit->w_size_, conv_param, deconv_param); + } + } + return; +} + +void DeconvWgPostFp16(const float16_t *tile_out, float16_t *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index) { + NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_); + /* merge */ + int src_unit_stride = deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + int src_stride = DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; + int dst_stride = conv_param->output_w_ * conv_param->output_h_ * C4NUM; + + for (int index = 0; index < calculate_count; ++index) { + const float16_t *src_start = tile_out + index * C4NUM; + + int plane_index = tile_index * DECONV_WINOGRAD_DEFAULT_TILE + index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_w_ - conv_param->pad_l_; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_h_ - conv_param->pad_u_; + float16_t *dst_start = nc4hw4_output + h_start * conv_param->output_w_ * C4NUM + w_start * C4NUM; + + int merge_w_start = MSMAX(-w_start, 0); + int merge_h_start = MSMAX(-h_start, 0); + int merge_h_end = MSMIN(deconv_param->out_tile_h_, conv_param->output_h_ - h_start); + int merge_w_end = MSMIN(deconv_param->out_tile_w_, conv_param->output_w_ - w_start); + + for (int hi = merge_h_start; hi < merge_h_end; hi++) { + for (int wi = merge_w_start; wi < merge_w_end; wi++) { + const float16_t *src = src_start + (hi * deconv_param->out_tile_w_ + wi) * src_unit_stride; + float16_t *dst = dst_start + (hi * conv_param->output_w_ + wi) * C4NUM; + DeConvWgMergeFp16(src, dst, src_stride, dst_stride, deconv_param->oc_div_); + } + } + } + return; +} + +#ifndef ENABLE_ARM +void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + const int unitStep = C4NUM * length; + for (int y = 0; y < h; ++y) { + float16_t *dstY = M + y * w * unitStep; + for (int x = 0; x < w; ++x) { + float16_t *dstX = dstY + x * unitStep; + const float16_t *srcX = S + x * unitStep; 
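+ /* M[y][x] = sum over i of B[i*h + y] * S[i][x], computed blockwise over unitStep elements: scalar reference left-transform used when the ARM assembly path is absent */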
+ memset(dstX, 0, unitStep * sizeof(float16_t)); + for (int i = 0; i < k; ++i) { + float16_t b = B[i * h + y]; + const float16_t *srcY = srcX + i * w * unitStep; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcY[j] * b; + } + } + } + } +} + +void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + const int unitStep = C4NUM * length; + for (int y = 0; y < h; ++y) { + float16_t *dstY = M + y * w * unitStep; + const float16_t *srcY = S + y * k * unitStep; + + for (int x = 0; x < w; ++x) { + float16_t *dstX = dstY + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float16_t)); + for (int i = 0; i < k; ++i) { + const float16_t *srcX = srcY + i * unitStep; + float16_t b = B[i * h + x]; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcX[j] * b; + } + } + } + } +} + +void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4, + size_t oc4) { + int dx, sz, dz; + int src_depth_step = C4NUM * DECONV_WINOGRAD_DEFAULT_TILE; + for (dz = 0; dz < oc4; ++dz) { + float16_t *dst_z = dst + dz * cal_num; + const float16_t *weight_dz = weight + dz * ic4 * 16; + for (dx = 0; dx < DECONV_WINOGRAD_DEFAULT_TILE; ++dx) { + float16_t *dst_x = dst_z + dx * C4NUM; + dst_x[0] = 0.0f; + dst_x[1] = 0.0f; + dst_x[2] = 0.0f; + dst_x[3] = 0.0f; + const float16_t *src_dx = src + C4NUM * dx; + for (sz = 0; sz < ic4; ++sz) { + const float16_t *src_z = src_dx + sz * src_depth_step; + const float16_t *weight_z = weight_dz + sz * 16; + for (int i = 0; i < C4NUM; ++i) { + for (int j = 0; j < C4NUM; ++j) { + dst_x[j] += src_z[i] * weight_z[C4NUM * i + j]; + } + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.h new file mode 100644 index 00000000..dbb5bcd1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/deconv_winograd_fp16.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_DECONV_WINOGRAD_FP16_H_ +#define NNACL_FP16_DECONV_WINOGRAD_FP16_H_ + +#include "nnacl_c/fp16/winograd_transform_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PackDeConvWgDataFp16(const float16_t *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param); + +void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index, + int calculate_count, const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id); + +void DeconvWgPostFp16(const float16_t *tile_out, float16_t *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index); + +void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4, + size_t oc4); + +void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length); + +void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_DECONV_WINOGRAD_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.c new file mode 100644 index 00000000..de3d1dea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/dynamic_quant_fp16.h" + +void CalculateMinMaxFp16(const float16_t *data, int count, float16_t *real_min, float16_t *real_max) { +#ifndef ENABLE_ARM64 + for (int i = 0; i < count; ++i) { + if (data[i] < *real_min) { + *real_min = data[i]; + } + if (data[i] > *real_max) { + *real_max = data[i]; + } + } +#else
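+  /* ARM64: CalculateMinMaxCount8Fp16 (implemented separately, presumably in assembly) reduces the 8-aligned prefix; the scalar loop below handles the remaining count % 8 elements */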
+ // volatile keeps the compiler from folding the rounded count into the call below. + volatile int count_8 = DOWN_ROUND(count, C8NUM); + CalculateMinMaxCount8Fp16(data, count_8, real_min, real_max); + for (int i = count_8; i < count; ++i) { + if (data[i] < *real_min) { + *real_min = data[i]; + } + if (data[i] > *real_max) { + *real_max = data[i]; + } + } +#endif +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.h new file mode 100644 index 00000000..1ab6cf5e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/dynamic_quant_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_DYNAMIC_QUANT_FP16_H_ +#define NNACL_FP16_DYNAMIC_QUANT_FP16_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CalculateMinMaxFp16(const float16_t *data, int count, float16_t *real_min, float16_t *real_max); + +#ifdef ENABLE_ARM64 +void CalculateMinMaxCount8Fp16(const float16_t *data, int count_8, float16_t *real_min, float16_t *real_max); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_DYNAMIC_QUANT_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c new file mode 100644 index 00000000..adeb0a7e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.c @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "nnacl_c/fp16/exp_fp16.h" +#include +#include +#include "nnacl_c/errorcode.h" + +#if defined(ENABLE_NEON) +static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) { + static float16x8_t maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static float16x8_t minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = vmaxq_f16(minv, vminq_f16(input, maxv)); + vst1q_f16(dst, VexpFp16(input)); +} +#endif + +void ExpFp16(const float16_t *src, float16_t *dst, int num) { + int i = 0; +#ifdef ENABLE_NEON + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(vld1q_f16(src + i), dst + i); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i], dst + i); + } +} + +int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) { + NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float16_t *src = (float16_t *)src_data; + float16_t *dst = (float16_t *)dst_data; + int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_); + int start = stride * task_id; + int end = MSMIN(exp->element_num_, start + stride); + int num = end - start; + + if (param->scale_ == 1) { + ExpFp16(src + start, dst + start, num); + } else { + int i = 0; +#ifdef ENABLE_ARM64 + MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->in_scale_); + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(MS_MULQ_F16(MS_LDQ_F16(src + i), scale), dst + i); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i] * exp->in_scale_, dst + i); + } + } + if (exp->out_scale_ != 1) { + int i = 0; +#ifdef ENABLE_ARM64 + MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->out_scale_); + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(MS_LDQ_F16(src + i), dst + i); + MS_STQ_F16(dst + i, MS_MULQ_F16(MS_LDQ_F16(dst + i), scale)); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i], dst + i); + dst[i] *= exp->out_scale_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h new file mode 100644 index 00000000..d5b30825 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/exp_fp16.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_EXP_FP16_H_ +#define NNACL_FP16_EXP_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/exp.h" +#include "nnacl_c/exp_parameter.h" +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ExpFp16(const float16_t *src, float16_t *dst, int num); +int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id); + +#ifdef ENABLE_NEON +static inline float16x8_t VexpFp16(float16x8_t input) { + float32x4_t input_low = MS_CVT_F32_F16(vget_low_f16(input)); + float32x4_t input_high = MS_CVT_F32_F16(vget_high_f16(input)); + return vcombine_f16(MS_CVT_F16_F32(VexpFp32(input_low)), MS_CVT_F16_F32(VexpFp32(input_high))); +} +#endif + +static inline void single_exp_fp16(float16_t src, float16_t *dst) { + static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; + int integer; + if (src > 0) { + src = MSMIN(88.72283935546875f, src); + integer = (float)src * 1.44269504088896341f + 0.5f; + } else { + src = MSMAX(-87.3365478515625f, src); + integer = (float)src * 1.44269504088896341f - 0.5f; + } + const int shift = 23; + const int bias = 126; + const float factor = 2; + float decimal = (float)src - integer * param[0];
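+  /* exp(x) = 2^n * e^d with n = round(x * log2(e)) and d = x - n * ln2: 2^(n-1) is assembled from the IEEE-754 exponent bits (bias 126 instead of 127), e^d comes from the 5th-order Taylor polynomial below, and factor == 2 restores the missing power of two */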
+ int int_exp = (integer + bias) << shift; + const float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + float *tmp = (float *)(&int_exp); + *dst = (float16_t)(*(tmp)*decimal_exp * factor); +} + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_EXP_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.c new file mode 100644 index 00000000..a4b32348 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.c @@ -0,0 +1,24 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/fill_fp16.h" + +inline int FillFp16(float16_t *output, int size, float16_t data) { + for (int i = 0; i < size; ++i) { + output[i] = data; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.h new file mode 100644 index 00000000..c177e9a9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/fill_fp16.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_FILL_FP16_H_ +#define NNACL_FP16_FILL_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fill_parameter.h" +#ifdef ENABLE_ARM +#include <arm_neon.h> +#endif + +#ifdef __cplusplus +extern "C" { +#endif +int FillFp16(float16_t *output, int size, float16_t data); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_FILL_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c new file mode 100644 index 00000000..ff2f9e44 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.c @@ -0,0 +1,148 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/gru_fp16.h" +#include <string.h> +#include "nnacl_c/fp16/lstm_fp16.h" +#include "nnacl_c/fp16/activation_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" + +void GruStepUnitFp16(float16_t *output, float16_t *update_gate, float16_t *reset_gate, float16_t *hidden_buffer, + const float16_t *state_weight, const float16_t *state_bias, float16_t *hidden_state, + float16_t *buffer[4], const GruParameter *gru_param) { + float16_t *packed_state = buffer[2]; + float16_t *state_gate = buffer[3]; + bool is_vec = gru_param->batch_ == 1; + + const float16_t *state_update_weight = state_weight; + const float16_t *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_; + const float16_t *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2; + float16_t *state_update_gate = state_gate; + float16_t *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_; + float16_t *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2; + const float16_t *state_update_bias = state_bias; + const float16_t *state_reset_bias = state_bias + gru_param->hidden_size_; + const float16_t *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2; + + // state * weight + if (is_vec) { + LstmMatMulFp16(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + LstmMatMulFp16(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + RowMajor2Col16MajorFp16(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + LstmMatMulFp16(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAddFp16(update_gate, 
state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + ElementAddFp16(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate, + gru_param->batch_ * gru_param->hidden_size_); + + // update reset_gate + SigmoidFp16(reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + + // update update_gate + SigmoidFp16(update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + + ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + if (is_vec) { + LstmMatMulFp16(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + RowMajor2Col16MajorFp16(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAddFp16(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + TanhFp16(hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + ElementMulFp16(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + float16_t one = 1.0f; + ElementOptSubFp16(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, true); + + ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float16_t)); +} + +void GruUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_g, + const float16_t *weight_r, const float16_t *input_bias, const float16_t *state_bias, + float16_t *hidden_state, float16_t *buffer[4], const GruParameter *gru_param, + bool is_backward) { + float16_t *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float16_t *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + gru_param->input_col_align_ * i; + float16_t *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float16_t *update_gate = gate; + float16_t *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float16_t *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? 
gru_param->seq_len_ - t - 1 : t; + float16_t *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnitFp16(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, + buffer, gru_param); + } +} + +void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4], + int check_seq_len, const GruParameter *gru_param) { + // forward + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_, false); + GruUnidirectionalFp16(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, + gru_param, false); + // zero out extra fw outputs + for (int t = check_seq_len; t < gru_param->seq_len_; t++) { + float16_t *output_ptr = output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + + // backward + if (gru_param->bidirectional_) { + const float16_t *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float16_t *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_; + const float16_t *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; + float16_t *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; + GruUnidirectionalFp16(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + + // zero out extra bw outputs + for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { + float16_t *output_ptr = backward_output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h new file mode 100644 index 00000000..ea8a7b71 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/gru_fp16.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_GRU_H_ +#define NNACL_FP16_GRU_H_ +#include "nnacl_c/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4], + int check_seq_len, const GruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRU_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c new file mode 100644 index 00000000..a174e46e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.c @@ -0,0 +1,217 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/instance_norm_fp16.h" +#include <math.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + NNACL_CHECK_ZERO_RETURN_ERR(hw_plane); + int channel_step = UP_DIV(channel, param->op_parameter_.thread_num_); + int channel_begin = task_id * channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); + + for (int b = 0; b < param->batch_; b++) { + const float16_t *src_b = src_data + b * channel * hw_plane; + float16_t *dst_b = dst_data + b * channel * hw_plane; + for (int c = channel_begin; c < channel_end; c++) { + const float16_t *src = src_b + c * hw_plane; + float16_t *dst = dst_b + c * hw_plane; + float mean = 0.0f; + float square_mean = 0.0f; + + int index = 0; + for (; index <= hw_plane - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + float16x8_t squarev = vmulq_f16(srcv, srcv); + + float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv)); + float32x4_t sum_f32 = vcvt_f32_f16(sum2); + mean += MS_ADDVQ_F32(sum_f32); + + float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev)); + float32x4_t square_f32 = vcvt_f32_f16(square2); + square_mean += MS_ADDVQ_F32(square_f32); + } + for (; index < hw_plane; index++) { + mean += src[index]; + square_mean += src[index] * src[index]; + } + + mean /= (float)hw_plane; + square_mean /= (float)hw_plane; + const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); + + index = 0; + float16x8_t meanv = vdupq_n_f16(mean); + float16x8_t denov = vdupq_n_f16(deno); + for (; index <= hw_plane - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + float16x8_t outv = vsubq_f16(srcv, meanv); + outv = vmulq_f16(outv, denov); + 
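+        /* affine stage: out = (x - mean) / sqrt(var + epsilon) * gamma[c] + beta[c], applied per channel */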
+ float16x8_t gammav = vdupq_n_f16(gamma_data[c]); + float16x8_t betav = vdupq_n_f16(beta_data[c]); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index, outv); + } + for (; index < hw_plane; index++) { + dst[index] = (src[index] - mean) * deno; + dst[index] = dst[index] * gamma_data[c] + beta_data[c]; + } + } + } + return NNACL_OK; +} + +int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + NNACL_CHECK_ZERO_RETURN_ERR(hw_plane); + int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); + int c8_down = channel_end / C8NUM * C8NUM; + int c_res = channel_end - c8_down; + float32x4_t hw_plane_4 = vdupq_n_f32(hw_plane); + for (int b = 0; b < param->batch_; b++) { + const float16_t *src_b = src_data + b * channel * hw_plane; + float16_t *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float16_t *src = src_b + c * hw_plane; + const float16_t *src1 = src_b + (c + C8NUM) * hw_plane; + float16_t *dst = dst_b + c; + float32x4_t mean1 = vdupq_n_f32(0.0f); + float32x4_t mean2 = vdupq_n_f32(0.0f); + float32x4_t mean3 = vdupq_n_f32(0.0f); + float32x4_t mean4 = vdupq_n_f32(0.0f); + float32x4_t square_mean1 = vdupq_n_f32(0.0f); + float32x4_t square_mean2 = vdupq_n_f32(0.0f); + float32x4_t square_mean3 = vdupq_n_f32(0.0f); + float32x4_t square_mean4 = vdupq_n_f32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM); + + float32x4_t srcv01 = vcvt_f32_f16(vget_low_f16(srcv)); + float32x4_t srcv02 = vcvt_f32_f16(vget_high_f16(srcv)); + float32x4_t srcv11 = vcvt_f32_f16(vget_low_f16(srcv1)); + float32x4_t srcv12 = vcvt_f32_f16(vget_high_f16(srcv1)); + mean1 = vaddq_f32(mean1, srcv01); + mean2 = vaddq_f32(mean2, srcv02); + mean3 = vaddq_f32(mean3, srcv11); + mean4 = vaddq_f32(mean4, srcv12); + square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv01, srcv01)); + square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv02, srcv02)); + square_mean3 = vaddq_f32(square_mean3, vmulq_f32(srcv11, srcv11)); + square_mean4 = vaddq_f32(square_mean4, vmulq_f32(srcv12, srcv12)); + } + float16x8_t mean = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4))); + float16x8_t mean_1 = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean3, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean4, hw_plane_4))); + float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4))); + float16x8_t square_mean_1 = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean3, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean4, hw_plane_4))); + float16x8_t deno = vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_)); + float16x8_t deno1 = vaddq_f16(vsubq_f16(square_mean_1, vmulq_f16(mean_1, mean_1)), vdupq_n_f16(param->epsilon_)); + deno = 1 / MS_SQRTFX8_F16(deno); + deno1 = 1 / 
MS_SQRTFX8_F16(deno1); + + float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno); // deno * gamma_data[c] + float16x8_t gammav1 = vmulq_f16(vld1q_f16(gamma_data + c + C8NUM), deno1); // deno * gamma_data[c] + float16x8_t betav = vld1q_f16(beta_data + c); + float16x8_t betav1 = vld1q_f16(beta_data + c + C8NUM); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM); + float16x8_t outv = vsubq_f16(srcv, mean); + float16x8_t outv1 = vsubq_f16(srcv1, mean_1); + outv = vmulq_f16(outv, gammav); + outv1 = vmulq_f16(outv1, gammav1); + outv = vaddq_f16(outv, betav); + outv1 = vaddq_f16(outv1, betav1); + vst1q_f16(dst + index * channel, outv); + vst1q_f16(dst + index * channel + C8NUM, outv1); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float16_t *src = src_b + c * hw_plane; + float16_t *dst = dst_b + c; + float32x4_t mean1 = vdupq_n_f32(0.0f); + float32x4_t mean2 = vdupq_n_f32(0.0f); + float32x4_t square_mean1 = vdupq_n_f32(0.0f); + float32x4_t square_mean2 = vdupq_n_f32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float32x4_t srcv1 = vcvt_f32_f16(vget_low_f16(srcv)); + float32x4_t srcv2 = vcvt_f32_f16(vget_high_f16(srcv)); + mean1 = vaddq_f32(mean1, srcv1); + mean2 = vaddq_f32(mean2, srcv2); + square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv1, srcv1)); + square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv2, srcv2)); + } + float16x8_t mean = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4))); + float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4))); + float16x8_t deno = + vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_)); // question + deno = 1 / MS_SQRTFX8_F16(deno); // question + + float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno); // deno * gamma_data[c] + float16x8_t betav = vld1q_f16(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t outv = vsubq_f16(srcv, mean); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index * channel, outv); + } + } + for (; c < channel_end; ++c) { + const float16_t *src = src_b + c8_down * hw_plane + c; + float16_t *dst = dst_b + c; + float mean = 0.0f; + float square_mean = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float16_t tmp = src[index * c_res]; + mean += tmp; + square_mean += tmp * tmp; + } + mean /= (float)hw_plane; + square_mean /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(square_mean - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.h new file mode 100644 index 00000000..fdbbd065 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/instance_norm_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_INSTANCE_NORM_FP16_H_ +#define NNACL_FP16_INSTANCE_NORM_FP16_H_ + +#include "nnacl_c/instance_norm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); +int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_INSTANCE_NORM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c new file mode 100644 index 00000000..d6e2b545 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.c @@ -0,0 +1,110 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp16/layer_norm_fp16.h" +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +int LayerNormMeanAndSquareFp16(const float16_t *src, int num, float16_t *mean, float16_t *variance) { + if (num <= 0) { + return NNACL_ERR; + } + int index = 0; + float sum = 0.0f; + float square_mean = 0.0f; + for (; index <= num - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + for (int i = 0; i < C8NUM; ++i) { + square_mean += srcv[i] * srcv[i]; + } + float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv)); + float32x4_t sum_f32 = vcvt_f32_f16(sum2); + sum += MS_ADDVQ_F32(sum_f32); + } + for (; index < num; index++) { + sum += src[index]; + square_mean += src[index] * src[index]; + } + *mean = (float16_t)(sum / num); + square_mean = square_mean / num; + *variance = square_mean - (*mean) * (*mean); + return NNACL_OK; +} + +void LayerNormGammaAndBetaFp16(float16_t *dst, const float16_t *src, const float16_t *gamma_data, + const float16_t *beta_data, int num, const float16_t mean, const float16_t deno) { + int index = 0; + float16x8_t meanv = vdupq_n_f16(mean); + float16x8_t denov = vdupq_n_f16(deno); + for (; index <= num - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + float16x8_t outv = vsubq_f16(srcv, meanv); + outv = vmulq_f16(outv, denov); + float16x8_t gammav = vld1q_f16(gamma_data + index); + float16x8_t betav = vld1q_f16(beta_data + index); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index, outv); + } + for (; index < num; index++) { + dst[index] = (src[index] - mean) * (deno); + dst[index] = dst[index] * gamma_data[index] + beta_data[index]; + } +} + +int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const float16_t *beta_data, + float16_t *dst_data, float16_t *out_mean, float16_t *out_variance, const LayerNormComputeParam *param, + int task_id, int thread_num) { + if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { + return NNACL_NULL_PTR; + } + NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_); + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + int step = UP_DIV(param->norm_outer_size_, thread_num); + int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_); + for (int i = task_id * step; i < thread_end; i++) { + const float16_t *src_norm = src_data + i * param->norm_inner_size_; + float16_t *dst_norm = dst_data + i * param->norm_inner_size_; + float16_t cur_mean = 0.0f; + float16_t cur_variance = 0.0f; + int ret = LayerNormMeanAndSquareFp16(src_norm, param->norm_inner_size_, &cur_mean, &cur_variance); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + if (out_mean != NULL) { + out_mean[i] = cur_mean; + } + if (out_variance != NULL) { + out_variance[i] = cur_variance; + } + const float16_t deno = 1 / sqrtf(cur_variance + param->epsilon_); + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const float16_t *src_param = src_norm + x * param->params_inner_size_; + float16_t *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBetaFp16(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean, + deno); + } + } else { + int x = i / param->params_outer_size_; + const float16_t *gamma = gamma_data + x * param->norm_inner_size_; + const float16_t *beta = beta_data + x * 
param->norm_inner_size_; + LayerNormGammaAndBetaFp16(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h new file mode 100644 index 00000000..dd5a1992 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/layer_norm_fp16.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_LAYER_NORM_FP16_H_ +#define NNACL_FP16_LAYER_NORM_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const float16_t *beta_data, + float16_t *dst_data, float16_t *out_mean, float16_t *out_variance, const LayerNormComputeParam *param, + int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_LAYER_NORM_FP16_H_
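Usage note: LayerNormFp16 above splits param->norm_outer_size_ rows across threads, computes per-row mean/variance, then applies gamma/beta (shared when norm_outer_size_ <= params_outer_size_, per-group otherwise). A minimal single-threaded call, as a hedged sketch — the shapes and field values here are illustrative assumptions, with LayerNormComputeParam supplied by nnacl_c/kernel/layer_norm.h:

  /* Sketch: normalize a [2, 8] fp16 tensor over its last axis with one thread. */
  float16_t src[16], dst[16], gamma[8], beta[8];  /* filled by the caller */
  LayerNormComputeParam param = {0};
  param.norm_outer_size_ = 2;    /* rows normalized independently */
  param.norm_inner_size_ = 8;    /* elements reduced per row */
  param.params_outer_size_ = 2;  /* >= norm_outer_size_, so one gamma/beta vector is shared */
  param.params_inner_size_ = 8;  /* length of the gamma/beta vectors */
  param.epsilon_ = 1e-5f;
  int ret = LayerNormFp16(src, gamma, beta, dst, NULL, NULL, &param, 0 /* task_id */, 1 /* thread_num */);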
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.c new file mode 100644 index 00000000..d00a4807 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.c @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/log_softmax_fp16.h" +#include <math.h> +#include <float.h> +#include "nnacl_c/fp16/softmax_fp16.h" +#include "nnacl_c/fp16/exp_fp16.h" + +void LogSoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, float16_t *exp_data, int batch, int channel) { + SoftmaxNormFp16(src, dst, batch, channel); + ExpFp16(dst, exp_data, batch * channel); + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + float16_t sum = 0; + int j = 0; +#ifdef ENABLE_NEON + float16x8_t sum8 = vdupq_n_f16(0); + int count = (channel / C8NUM) * C8NUM; + for (; j < count; j += C8NUM) { + sum8 = vaddq_f16(sum8, vld1q_f16(exp_data + cur_batch_offset + j)); + } + sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7]; +#endif + for (; j < channel; j++) { + sum += exp_data[cur_batch_offset + j]; + } + for (int k = 0; k < channel; k++) { + dst[cur_batch_offset + k] = dst[cur_batch_offset + k] - log(sum); + } + } +} + +// output = (input - reduce_max(input, axis)) - log(reduce_sum(exp(input - reduce_max(input, axis)), axis)) +void LogSoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int *input_shape, int n_dim, + int axis) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float16_t max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = input_ptr[axis_offset] - max_data; + sum_data[k + sum_outter_offset] += exp(output_ptr[axis_offset]); + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] - log(sum_data[k + sum_outter_offset]); + } + } + } +}
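Both kernels implement the numerically stable identity spelled out in the comment above LogSoftmaxFp16: subtracting the axis max before exponentiating keeps exp from overflowing. As a plain-C cross-check for a single row along the last axis (an illustrative sketch in fp32 for clarity, not part of the patch):

  #include <math.h>
  /* Reference: log_softmax(x)[k] = (x[k] - max) - log(sum_j exp(x[j] - max)) */
  static void LogSoftmaxRowRef(const float *x, float *y, int n) {
    float max_v = x[0];
    for (int j = 1; j < n; ++j) max_v = x[j] > max_v ? x[j] : max_v;
    float sum = 0.0f;
    for (int j = 0; j < n; ++j) sum += expf(x[j] - max_v);
    for (int k = 0; k < n; ++k) y[k] = (x[k] - max_v) - logf(sum);
  }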
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.h new file mode 100644 index 00000000..f4cf14d0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/log_softmax_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_LOG_SOFTMAX_FP16_H_ +#define NNACL_FP16_LOG_SOFTMAX_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/softmax_parameter.h" +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#ifdef __cplusplus +extern "C" { +#endif +void LogSoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, float16_t *exp_data, int batch, int channel); +void LogSoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int *input_shape, int n_dim, + int axis); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_LOG_SOFTMAX_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.c new file mode 100644 index 00000000..d812b6e1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.c @@ -0,0 +1,367 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/lstm_fp16.h" +#include <string.h> +#include <float.h> +#include "nnacl_c/fp16/activation_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/matmul_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, + const int32_t *order) { + for (int i = 0; i < batch; i++) { + const float *src_batch = src + i * col * deep; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep; +#ifdef ENABLE_ARM64 + RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, true); +#else + RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, true); +#endif + } +} + +void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, + const int32_t *order) { + for (int i = 0; i < batch; i++) { + const float16_t *src_batch = src + i * col * deep; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep; +#ifdef ENABLE_ARM64 + RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, false); +#else + RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, false); +#endif + } +} + +void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; + Float32ToFloat16(src_batch, dst_batch, col); + } + if (is_bidirectional) { + const float *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + (order == NULL ?
i : order[i]) * col_align; + Float32ToFloat16(backward_src_batch, backward_dst_batch, col); + } + } +} + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *src_batch = src + i * col; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; + (void)memcpy(dst_batch, src_batch, col * sizeof(float16_t)); + } + if (is_bidirectional) { + const float16_t *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + (order == NULL ? i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t)); + } + } +} + +// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] +void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, + int inner_size) { + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + float16_t res = 0; + const float16_t *input_col = input + r * inner_size; + const float16_t *weight_col = weight + c * inner_size; + int index = 0; + float16x8_t out = vdupq_n_f16(0.0f); + for (; index <= inner_size - 8; index += 8) { + float16x8_t in_0 = vld1q_f16(input_col + index); + float16x8_t in_1 = vld1q_f16(weight_col + index); + out = vfmaq_f16(out, in_1, in_0); + } + float16x4_t add2 = vadd_f16(vget_low_f16(out), vget_high_f16(out)); + float16x4_t add4 = vpadd_f16(add2, add2); + float16x4_t add8 = vpadd_f16(add4, add4); + res += vget_lane_f16(add8, 0); + for (; index < inner_size; index++) { + res += input_col[index] * weight_col[index]; + } + output[r * cols + c] += res; + } + } +} + +void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; + for (; index <= element_size - 8; index += 8) { + float16x8_t in_0 = vld1q_f16(input0 + index); + float16x8_t in_1 = vld1q_f16(input1 + index); + float16x8_t out = vld1q_f16(output + index); + out = vfmaq_f16(out, in_1, in_0); + vst1q_f16(output + index, out); + } + for (; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size) { + int index = 0; + for (; index <= element_size - 8; index += 8) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vld1q_f16(output + index); + vout = MS_FMAQ_N_F16(vout, vin0, input1); + vst1q_f16(output + index, vout); + } + for (; index < element_size; index++) { + output[index] += input0[index] * input1; + } + return NNACL_OK; +} + +void UpdateStateFp16(float16_t *cell_state, const float16_t *forget_gate, const float16_t *input_gate, + const float16_t *cell_gate, float16_t *state_buffer, int batch, int hidden_size, + float16_t zoneout) { + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // zoneout * old_cell_state + (void)memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float16_t)); + ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); + } + + ElementMulFp16(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAccFp16(input_gate, 
cell_gate, cell_state, batch * hidden_size); + + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // (1 - zoneout) * new_cell_state + ElementOptMulAccFp16(cell_state, 1 - zoneout, state_buffer, batch * hidden_size); + } +} + +void UpdateOutputFp16(float16_t *hidden_state, float16_t *output, const float16_t *cell_state, float16_t *output_gate, + const float16_t *weight_project, const float16_t *project_bias, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param) { + int batch = lstm_param->batch_; + int hidden_size = lstm_param->hidden_size_; + int output_size = lstm_param->output_size_; + float16_t *state_buffer = buffer[C5NUM]; + float16_t *hidden_buffer = weight_project ? buffer[C3NUM] : hidden_state; + float16_t zoneout = lstm_param->zoneout_hidden_; + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float16_t)); + ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * output_size, false); + } + + TanhFp16(cell_state, hidden_buffer, batch * hidden_size); + ElementMulFp16(hidden_buffer, output_gate, hidden_buffer, batch * hidden_size); + + if (weight_project) { + float16_t *left_matrix = hidden_buffer; +#ifdef ENABLE_ARM64 + if (batch >= C4NUM) { + left_matrix = buffer[C6NUM]; + RowMajor2ColLadder12MajorFp16(hidden_buffer, left_matrix, batch, hidden_size); + } +#else + if (batch != 1) { + left_matrix = buffer[C6NUM]; + RowMajor2Col16MajorFp16(hidden_buffer, left_matrix, batch, hidden_size, false); + } +#endif + LstmMatMulFp16(hidden_state, left_matrix, weight_project, project_bias, batch, hidden_size, output_size, + batch == 1); + } + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + ElementOptMulAccFp16(hidden_state, 1 - zoneout, state_buffer, batch * output_size); + } + (void)memcpy(output, hidden_state, batch * output_size * sizeof(float16_t)); +} + +#ifdef ENABLE_ARM64 +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec) { + MatmulFp16OptV2(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); +} +#else +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec) { + if (is_vec) { + (void)memcpy(c, bias, col * sizeof(float16_t)); + MatMulAccFp16(c, a, b, row, col, deep); + } else { + MatMulFp16(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} +#endif + +void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, const float16_t *bias, + int row, int deep, int col, int col_align, bool is_vec) { + for (int i = 0; i < 4; i++) { + const float16_t *weight_i = weight + deep * col_align * i; + const float16_t *bias_i = bias + col_align * i; + float16_t *gate = gate_buffer + row * col * i; + LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); + } +} + +void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forget_gate, float16_t *cell_gate, + float16_t *output_gate, const float16_t *state_weight, const float16_t *state_bias, + const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, + float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param) { + float16_t *packed_state = buffer[C2NUM]; + float16_t *state_gate = buffer[C3NUM]; + float16_t *cell_buffer = buffer[C4NUM]; + float16_t *hidden_buffer = buffer[C5NUM]; + bool is_vec = lstm_param->batch_ == 
1; +#ifdef ENABLE_ARM64 + if (lstm_param->batch_ <= C3NUM) { + UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } else { + RowMajor2ColLadder12MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); + UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } +#else + if (is_vec) { + UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } else { + RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_, false); + UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } +#endif + ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + + // update input_gate + SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + + // update forget_gate + SigmoidFp16(forget_gate, forget_gate, lstm_param->batch_ * lstm_param->hidden_size_); + + // update cell_gate + TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_); + // update cell state + UpdateStateFp16(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, + lstm_param->hidden_size_, lstm_param->zoneout_cell_); + + // update output_gate + SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_); + // update output + UpdateOutputFp16(hidden_state, output, cell_state, output_gate, weight_project, project_bias, buffer, lstm_param); + + if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { + (void)memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); + } + + if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { + (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float16_t)); + } +} + +void LstmGateCompute(float16_t *gate, const float16_t *input, const float16_t *weight_i, const float16_t *input_bias, + const LstmParameter *lstm_param) { + int row_input = lstm_param->seq_len_ * lstm_param->batch_; + for (int i = 0; i < C4NUM; i++) { + const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; +#ifdef ENABLE_ARM64 + MatmulFp16OptV2(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, + lstm_param->hidden_size_, 
lstm_param->hidden_size_, OutType_Nhwc); +#else + MatMulFp16(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, + lstm_param->hidden_size_, lstm_param->hidden_size_, OutType_Nhwc); +#endif + } +} + +void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, + const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias, + const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, + float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param, + bool is_backward) { + float16_t *gate = buffer[1]; + LstmGateCompute(gate, packed_input, weight_i, input_bias, lstm_param); + + float16_t *input_gate = gate; + float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float16_t *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float16_t *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = 0; t < lstm_param->seq_len_; t++) { + int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t; + float16_t *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnitFp16(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, + weight_project, project_bias, hidden_state, cell_state, buffer, lstm_param); + } +} + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *input_bias, const float16_t *state_bias, const float16_t *weight_project, + const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param) { + // forward +#ifdef ENABLE_ARM64 + const float16_t *packed_input = input; + if (lstm_param->batch_ * lstm_param->seq_len_ >= C4NUM) { + float16_t *temp_input = buffer[0]; + RowMajor2ColLadder12MajorFp16(input, temp_input, lstm_param->seq_len_ * lstm_param->batch_, + lstm_param->input_size_); + packed_input = temp_input; + } +#else + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_, + false); +#endif + LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, state_bias, weight_project, project_bias, + hidden_state, cell_state, buffer, lstm_param, false); + + // backward + if (lstm_param->bidirectional_) { + const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; + const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; + const float16_t *backward_weight_project = + weight_project ? 
weight_project + lstm_param->hidden_size_ * lstm_param->proj_col_align_ : NULL; + float16_t *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; + float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; + + LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, + backward_state_bias, backward_weight_project, project_bias, backward_hidden_state, + backward_cell_state, buffer, lstm_param, true); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.h new file mode 100644 index 00000000..675643f9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/lstm_fp16.h @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_LSTM_FP16_H_ +#define NNACL_FP16_LSTM_FP16_H_ + +#include "nnacl_c/lstm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, + const int32_t *order); + +void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, + const int32_t *order); + +void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec); + +void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, + int inner_size); + +void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size); + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *input_bias, const float16_t *state_bias, const float16_t *weight_project, + const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_LSTM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.c new file mode 100644 index 00000000..9d239cde --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.c @@ -0,0 +1,1204 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/matmul_fp16.h" + +static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + int row_c8 = row / C8NUM * C8NUM; + int col_c8 = col / C8NUM * C8NUM; + const float16_t *src = (const float16_t *)src_ptr; + int ci = 0; + for (; ci < col_c8; ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float16_t *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; +#ifdef ENABLE_ARM64 + size_t strid_row = row * 2; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v0.8h}, [x10], x12\n" + "ld1 {v1.8h}, [x10], x12\n" + "ld1 {v2.8h}, [x10], x12\n" + "ld1 {v3.8h}, [x10], x12\n" + "ld1 {v4.8h}, [x10], x12\n" + "ld1 {v5.8h}, [x10], x12\n" + "ld1 {v6.8h}, [x10], x12\n" + "ld1 {v7.8h}, [x10], x12\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [dst_ptr1] "r"(dst_ptr1), [src_ptr1] "r"(src_ptr1), [strid_row] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else + for (int tr = 0; tr < C8NUM; ++tr) { + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr]; + } + } +#endif + } + for (; ri < row; ++ri) { + const float16_t *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri]; + } + } + } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r]; + } + } +} + +static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + int row_c8 = row / C8NUM * C8NUM; + int col_c8 = col / C8NUM * C8NUM; + int ci = 0; + const float *src = (const float *)src_ptr; + for (; ci < col_c8; 
ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; +#ifdef ENABLE_ARM64 + size_t strid_row = row * 4; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v8.4s, v9.4s}, [x10], x12\n" + "ld1 {v10.4s, v11.4s}, [x10], x12\n" + "ld1 {v12.4s, v13.4s}, [x10], x12\n" + "ld1 {v14.4s, v15.4s}, [x10], x12\n" + "ld1 {v16.4s, v17.4s}, [x10], x12\n" + "ld1 {v18.4s, v19.4s}, [x10], x12\n" + "ld1 {v20.4s, v21.4s}, [x10], x12\n" + "ld1 {v22.4s, v23.4s}, [x10], x12\n" + + "fcvtn v0.4h, v8.4s\n" + "fcvtn2 v0.8h, v9.4s\n" + "fcvtn v1.4h, v10.4s\n" + "fcvtn2 v1.8h, v11.4s\n" + "fcvtn v2.4h, v12.4s\n" + "fcvtn2 v2.8h, v13.4s\n" + "fcvtn v3.4h, v14.4s\n" + "fcvtn2 v3.8h, v15.4s\n" + "fcvtn v4.4h, v16.4s\n" + "fcvtn2 v4.8h, v17.4s\n" + "fcvtn v5.4h, v18.4s\n" + "fcvtn2 v5.8h, v19.4s\n" + "fcvtn v6.4h, v20.4s\n" + "fcvtn2 v6.8h, v21.4s\n" + "fcvtn v7.4h, v22.4s\n" + "fcvtn2 v7.8h, v23.4s\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [dst_ptr1] "r"(dst_ptr1), [src_ptr1] "r"(src_ptr1), [strid_row] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else + for (int tr = 0; tr < C8NUM; ++tr) { + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]); + } + } +#endif + } + for (; ri < row; ++ri) { + const float *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]); + } + } + } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]); + } + } +} + +void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) { + if (src_float16) { + Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col); + } else { + Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col); + } + return; +} + +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int 
r16div = r / 16, r16mod = r % 16; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * 16 + d * 16 + r16mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { + int col_8 = UP_ROUND(col, C8NUM); + int row_16 = UP_ROUND(row, C16NUM); + for (int r = 0; r < row_16; r++) { + for (int c = 0; c < col_8; c++) { + int r16div = r / C16NUM, r16mod = r % C16NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod); + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else { + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C16NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { // common conv and matmul + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { // common deconv + int col_8 = UP_ROUND(col, C8NUM); + int row_12 = UP_ROUND(row, C12NUM); + for (int r = 0; r < row_12; r++) { + for (int c = 0; c < col_8; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else { // winograd conv + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C12NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + 
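The two reference kernels above consume pre-packed operands: A in 16-row or 12-row column-ladder tiles and B in 8-column tiles, which is why the index math uses r16div/r12div and c8div. A hedged usage sketch for the 12x8 fallback path — a_pack/b_pack are hypothetical buffers sized to the rounded-up tile counts, and B is passed transposed (col x deep) as the packing routines below expect:

  /* A: row x deep, b_trans: col x deep, dst: row x col. */
  RowMajor2Col12MajorFp16(a, a_pack, row, deep, /* is_fp32_src = */ false);
  RowMajor2Col8MajorFp16(b_trans, b_pack, col, deep, /* is_fp32_src = */ false);
  MatMul12x8Fp16(a_pack, b_pack, dst, bias, ActType_No, deep, row, col, /* stride = */ col, OutType_Nhwc);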
+#ifdef ENABLE_DEBUG +void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + size_t index = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[index] = value; + } + } +} +#endif + +void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int row, int col, int stride, int out_type) { + if (out_type == OutType_C8) { + // common deconv +#ifdef ENABLE_ARM64 + MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); +#else + MatMul12x8Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif + } else { + // winograd conv(OutType_TileC8) and common conv(OutType_Nhwc) and matmul(OutType_Nhwc) +#ifdef ENABLE_ARM64 + MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#else + MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif + } + return; +} + +#ifdef ENABLE_ARM64 +// 1*8 X 8*8 -> 1 X 8 +void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, + int col) { + int ci = col; + const float16_t *bv_base = b; + + while (ci > 0) { + float16x8_t acc_0 = vdupq_n_f16((float16_t)0.0); + if (bias != NULL) { + acc_0 = vld1q_f16(bias); + bias += C8NUM; + } + + int di = 0; + for (; di < depth - C8NUM + 1; di += C8NUM) { + float16x8_t av = vld1q_f16(a + di); + float16x8_t bv_0; + float16x8_t bv_1; + for (int i = 0; i < C8NUM; i += C2NUM) { + bv_0 = vld1q_f16(bv_base); // bv_0 is one row of B: 8 columns of data + acc_0 = vfmaq_n_f16(acc_0, bv_0, av[i]); // av[i] is a single value from the vector + bv_base += C8NUM; + + bv_1 = vld1q_f16(bv_base); // bv_1 is the next row of B: 8 columns of data + acc_0 = vfmaq_n_f16(acc_0, bv_1, av[i + 1]); // av[i + 1] is the next value from the vector + bv_base += C8NUM; + } + } + if (di < depth) { + for (; di < depth; ++di) { + float16_t ai = a[di]; + float16x8_t bv0 = vld1q_f16(bv_base); + bv_base += C8NUM; + acc_0 = vfmaq_n_f16(acc_0, bv0, ai); + } + } + if (act_type == ActType_Relu) { + acc_0 = vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0)); + } + if (act_type == ActType_Relu6) { + acc_0 = vminq_f16(vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0)), vdupq_n_f16((float16_t)6.0)); + } + + // only save actual col num data + if (ci < C8NUM) { + for (int i = 0; i < ci; ++i) { + c[i] = acc_0[i]; + } + return; + } + vst1q_f16(c, acc_0); + c += C8NUM; + ci -= C8NUM; + } +} +#endif + +#ifdef ENABLE_ARM82_A32 +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col) { + for (int ci = 0; ci < col; ci++) { + float value = 0; + for (int di = 0; di < depth; di++) { + value += a[di] * b[ci * depth + di]; + } + if (bias != NULL) value += bias[ci]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value); + c[ci] = value; + } +} +#endif + +void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType
act_type, + int depth, int col) { +#ifdef ENABLE_ARM64 + MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); +#else + MatVecMulA32NeonFp16(a, b, c, bias, (int)act_type, depth, col); +#endif +} + +#ifdef ENABLE_ARM64 +static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { + size_t stride = col * 2; + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[stride]\n" + "ld1 {v9.8h}, [x10], %[stride]\n" + "ld1 {v10.8h}, [x10], %[stride]\n" + "ld1 {v11.8h}, [x10], %[stride]\n" + "ld1 {v12.8h}, [x10], %[stride]\n" + "ld1 {v13.8h}, [x10], %[stride]\n" + "ld1 {v14.8h}, [x10], %[stride]\n" + "ld1 {v15.8h}, [x10], %[stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + : + : [dst_c] "r"(dst_ptr), [src_c] "r"(src_ptr), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t 
col) { + size_t row_up_16 = UP_ROUND(row, C16NUM); + size_t row16 = row / C16NUM * C16NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + size_t ri = 0; + // find 16 block unit + for (; ri < row16; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_ARM64 + Row2Col16Block16(src_c, dst_c, col); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + for (; ri < row_up_16; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + return; +} + +#ifdef ENABLE_ARM64 +void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col) { + // Col12Major ==> Col8Major ==> Col4Major + const float16_t *src_r = src; + float16_t *dst_r = dst_ptr; + int ri = 0; + size_t col8 = col / C8NUM * C8NUM; + // find 16 block unit + for (; ri <= row - C12NUM; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + Transpose12x8ARM64Fp16(src_c, dst_c, col * C2NUM, C24NUM); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (size_t i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + if (ri < row) { + memcpy(dst_r, src_r, (row - ri) * col * C2NUM); + } +} + +void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col) { + // Row12 ==> Row8 ==> Row4 + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C12NUM; c += C12NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_FLOAT16X4 src_data1 = MS_LD_F16(src + r * col + c + C8NUM); + MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM, src_data); + MS_ST_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM + C8NUM, src_data1); + } + for (; c <= col - C8NUM; c += C8NUM) { + MS_FLOAT16X8 src_data = 
MS_LDQ_F16(src + r * col + c); + MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C8NUM, src_data); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r + c % C4NUM * row] = src[r * col + c]; + } + } +} + +void RowMajor2ColNMajorFp16srcFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) { + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + int ri = 0; + size_t col8 = col / C8NUM * C8NUM; + // find 16 block unit + for (; ri <= row - C16NUM; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + Row2Col16Block16(src_c, dst_c, col); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (size_t i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } +} + +void RowMajor2ColNMajorFp16(const void *src_ptr, float16_t *dst_ptr, int row, int col, bool is_fp32_src) { + // Col16Major ==> Col8Major ==> Col4Major + if (!is_fp32_src) { + RowMajor2ColNMajorFp16srcFp16((const float16_t *)src_ptr, dst_ptr, row, col); + return; + } + const float *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + int ri = 0; + // find 16 block unit + for (; ri <= row - C16NUM; ri += C16NUM) { + for (int r = 0; r < C16NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C16NUM + r % C16NUM] = src_r[r * col + c]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + for (int r = 0; r < C8NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C8NUM + r % C8NUM] = src_r[r * col + c]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + for (int r = 0; r < C4NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C4NUM + r % C4NUM] = src_r[r * col + c]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ++ri) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } +} + +void RowMajor2RowNMajorFp16(const void *src_ptr, float16_t *dst, int row, int col, bool is_fp32_src) { + // Row16 ==> Row8 ==> 
Row4 + if (is_fp32_src) { + const float *src = (const float *)src_ptr; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C16NUM; c += C16NUM) { + const float *cur_src = src + r * col + c; + MS_FLOAT32X4X4 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM), MS_LDQ_F32(cur_src + C8NUM), + MS_LDQ_F32(cur_src + C12NUM)}; + MS_FLOAT16X4X4 res = { + MS_CVT_F16_F32(src_f32_data.val[0]), + MS_CVT_F16_F32(src_f32_data.val[1]), + MS_CVT_F16_F32(src_f32_data.val[2]), + MS_CVT_F16_F32(src_f32_data.val[3]), + }; + MS_ST4_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, res); + } + for (; c <= col - C8NUM; c += C8NUM) { + const float *cur_src = src + r * col + c; + MS_FLOAT32X4X2 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM)}; + MS_FLOAT16X4X2 res = { + MS_CVT_F16_F32(src_f32_data.val[0]), + MS_CVT_F16_F32(src_f32_data.val[1]), + }; + MS_ST2_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, res); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_CVT_F16_F32(MS_LDQ_F32(src + r * col + c)); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; + } + } + return; + } + const float16_t *src = (const float16_t *)src_ptr; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C16NUM; c += C16NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM); + MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data); + MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1); + } + for (; c <= col - C8NUM; c += C8NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_STQ_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, src_data); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; + } + } +} +#endif + +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + size_t ri = 0; + // transpose 12x8 + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM64 + Transpose12x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), 24); +#elif ENABLE_ARM82_A32 + Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + for (; ri < row_up_12; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool 
is_fp32_src) { + if (is_fp32_src) { + const float *fp32_src = (const float *)src; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div16 = r / 16; + int r_mod16 = r % 16; + dst[r_div16 * 16 * col + c * 16 + r_mod16] = (float16_t)(fp32_src[r * col + c]); + } + } + } else { + const float16_t *fp16_src = (const float16_t *)src; + RowMajor2Col16MajorFp16Opt(fp16_src, dst, row, col); + } + return; +} + +void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if (is_fp32_src) { + const float *fp32_src = (const float *)src; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div12 = r / 12; + int r_mod12 = r % 12; + dst[r_div12 * 12 * col + c * 12 + r_mod12] = (float16_t)(fp32_src[r * col + c]); + } + } + } else { + const float16_t *fp16_src = (const float16_t *)src; + RowMajor2Col12MajorFp16Opt(fp16_src, dst, row, col); + } + return; +} + +void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div16 = c / 16; + int c_mod16 = c % 16; + if (is_fp32_src) { + dst[c_div16 * 16 * row + r * 16 + c_mod16] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c_div16 * 16 * row + r * 16 + c_mod16] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) { + int col_align = UP_ROUND(col, C16NUM); + for (int r = 0; r < row; r++) { + int c = 0; + for (; c < col; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = src[r * col + c]; + } + for (; c < col_align; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = (float16_t)0.0; + } + } +} + +void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div12 = c / 12; + int c_mod12 = c % 12; + if (is_fp32_src) { + dst[c_div12 * 12 * row + r * 12 + c_mod12] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c_div12 * 12 * row + r * 12 + c_mod12] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + int down_c8 = col / C8NUM; + int stride = C8NUM * row; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c < down_c8; c++) { + MS_FLOAT16X8 src_data = MS_LDQ_F16((const float16_t *)src + r * col + c * C8NUM); + MS_STQ_F16(dst + c * stride + r * C8NUM, src_data); + } + c *= C8NUM; + for (; c < col; c++) { + int c_div8 = c / 8; + int c_mod8 = c % 8; + dst[c_div8 * stride + r * 8 + c_mod8] = ((const float16_t *)src)[r * col + c]; + } + } +} + +void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + if (is_fp32_src) { + dst[c * row + r] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c * row + r] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col8MajorFp16_arm64(const float16_t *src_c, float16_t *dst_c, size_t col) { + size_t stride = col * sizeof(float16_t); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, 
[x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v0.8h, v1.8h\n" + "zip1 v10.8h, v2.8h, v3.8h\n" + "zip2 v11.8h, v2.8h, v3.8h\n" + "zip1 v12.8h, v4.8h, v5.8h\n" + "zip2 v13.8h, v4.8h, v5.8h\n" + "zip1 v14.8h, v6.8h, v7.8h\n" + "zip2 v15.8h, v6.8h, v7.8h\n" + + "trn1 v16.4s, v8.4s, v10.4s\n" + "trn2 v17.4s, v8.4s, v10.4s\n" + "trn1 v18.4s, v12.4s, v14.4s\n" + "trn2 v19.4s, v12.4s, v14.4s\n" + "trn1 v20.4s, v9.4s, v11.4s\n" + "trn2 v21.4s, v9.4s, v11.4s\n" + "trn1 v22.4s, v13.4s, v15.4s\n" + "trn2 v23.4s, v13.4s, v15.4s\n" + + "trn1 v0.2d, v16.2d, v18.2d\n" + "trn1 v1.2d, v17.2d, v19.2d\n" + "trn2 v2.2d, v16.2d, v18.2d\n" + "trn2 v3.2d, v17.2d, v19.2d\n" + "trn1 v4.2d, v20.2d, v22.2d\n" + "trn1 v5.2d, v21.2d, v23.2d\n" + "trn2 v6.2d, v20.2d, v22.2d\n" + "trn2 v7.2d, v21.2d, v23.2d\n" + + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], #64\n" + "st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x11], #64\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + return; +} +#endif + +void RowMajor2Col8MajorFp16SrcFp16(const float16_t *src, float16_t *dst, int row, int col) { + int row8 = row / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + int col_skip = col / C8NUM * C8NUM; + int skip_size = C8NUM; +#else + int col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; +#endif + const float16_t *src_r = src; + float16_t *dst_r = dst; + + int ri = 0; + for (; ri < row8; ri += C8NUM) { + int ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8MajorFp16_arm64(src_c, dst_c, col); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C8NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri < row; ri++, src_r += col, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + } + + for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = 0; + } + } +} + +void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if (!is_fp32_src) { + return RowMajor2Col8MajorFp16SrcFp16(src, dst, row, col); + } + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div8 = r / 8; + int r_mod8 = r % 8; + dst[r_div8 * 8 * col + c * 8 + r_mod8] = (float16_t)(((const float *)src)[r * col + c]); + } + } +} + +#if defined(ENABLE_DEBUG) && defined(ENABLE_ARM64) +// arm64 matmul +void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc) { + int r16 = row / C16NUM * C16NUM; + int r8 = row / C8NUM * C8NUM; + for (int r = 0; r < row; ++r) { + int row_tile = 0; + if (r < r16) { + row_tile = C16NUM; + } else if (r < r8) { + row_tile = C8NUM; + } else { + row_tile = C4NUM; + } + int index = r / row_tile * row_tile * 
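/* Editorial sketch, not part of the patch: the zip1/zip2 + trn1/trn2 ladder in
 * RowMajor2Col8MajorFp16_arm64 above is an 8x8 fp16 transpose. This scalar
 * reference (name illustrative) computes the same result: `col` is the source
 * row stride in elements, matching the `stride = col * sizeof(float16_t)`
 * passed to the assembly, and the 64 outputs are stored contiguously. */
static void Transpose8x8RefFp16(const float16_t *src, float16_t *dst, size_t col) {
  for (int r = 0; r < 8; ++r) {
    for (int c = 0; c < 8; ++c) {
      dst[c * 8 + r] = src[r * col + c]; /* same mapping as the scalar fallback paths */
    }
  }
}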
depth + r % row_tile; + for (int t = 0; t < col; ++t) { + int c_div = t / C8NUM; + int c_mod = t % C8NUM; + float16_t res = bias[t]; + for (int d = 0; d < depth; ++d) { + res += a[index + d * row_tile] * b[c_div * depth * C8NUM + d * C8NUM + c_mod]; + } + c[r * col + t] = res; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.h new file mode 100644 index 00000000..4a8e94bb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matmul_fp16.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_MATMUL_FP16_H_ +#define NNACL_FP16_MATMUL_FP16_H_ + +#include +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" + +#define ADD_BIAS(value, bias, c) \ + if (bias != NULL) value = value + bias[c]; + +#define DO_RELU(value, act_type) \ + if (act_type == ActType_Relu) value = MSMAX(0.0f, value); + +#define DO_RELU6(value, act_type) \ + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \ + if (act_type == ActType_Relu6) value = MSMAX(0.0f, value); + +#ifdef __cplusplus +extern "C" { +#endif +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +#ifdef ENABLE_ARM64 +void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col); + +void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col); + +void RowMajor2ColNMajorFp16(const void *src, float16_t *dst_ptr, int row, int col, bool is_fp32_src); + +void RowMajor2RowNMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type); +void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); + +void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t 
depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +#ifdef ENABLE_DEBUG +void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); +#endif + +void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); + +void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, + int col); +void VecMatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); +#elif ENABLE_ARM82_A32 +void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); + +void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); +#endif + +void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type); + +void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int row, int col, int stride, int out_type); + +void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int col); + +void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16); + +void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); + +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); + +void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col); + +void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_MATMUL_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.c new file mode 100644 index 00000000..518a68ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.c @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
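/* Editorial sketch, not part of the patch: the contract shared by the MatMul*
 * kernels declared above, written as a plain row-major reference using the
 * ADD_BIAS / DO_RELU / DO_RELU6 macros from this header. The packed variants
 * compute the same values from retiled A and B; the function name is
 * illustrative only. */
static void MatMulRefFp16(const float16_t *a, const float16_t *b, float16_t *c,
                          const float16_t *bias, ActType act_type, int depth, int row, int col) {
  for (int r = 0; r < row; ++r) {
    for (int t = 0; t < col; ++t) {
      float16_t value = 0;
      for (int d = 0; d < depth; ++d) {
        value += a[r * depth + d] * b[d * col + t]; /* A and B both row-major here */
      }
      ADD_BIAS(value, bias, t);
      DO_RELU(value, act_type);
      DO_RELU6(value, act_type);
      c[r * col + t] = value;
    }
  }
}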
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/matrix_fp16.h" + +void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16_t res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} + +#ifndef ENABLE_ARM64 +void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n, int in_channel) { + int cnt = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + for (int y = 0; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + matrix_c[cnt++] = tmp; + } + } + } +} +#endif + +void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c, + const float16_t *bias, int m, int k, int n) { + if (bias == NULL) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16x8_t res = vmovq_n_f16(0); + for (int i = 0; i < k; i++) { + res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n])); + } + matrix_c[count] = res; + count++; + } + } + } else { + int count = 0; + float16x8_t bias_ptr = vld1q_f16(bias); + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16x8_t res = vmovq_n_f16(0); + for (int i = 0; i < k; i++) { + res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n])); + } + matrix_c[count] = vaddq_f16(res, bias_ptr); + count++; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.h new file mode 100644 index 00000000..e347c242 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/matrix_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
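/* Editorial sketch, not part of the patch: MatrixMultiplyVecFp16 above treats
 * every matrix element as a float16x8_t, i.e. it runs eight independent
 * (m x k) * (k x n) products lane-wise and adds the same 8-wide bias to every
 * output element. Scalar emulation (name illustrative), with each element
 * spelled out as its 8 lanes: */
static void MatrixMultiplyVecRefFp16(const float16_t (*a)[8], const float16_t (*b)[8],
                                     float16_t (*c)[8], const float16_t *bias, int m, int k, int n) {
  for (int h = 0; h < m; ++h) {
    for (int w = 0; w < n; ++w) {
      for (int l = 0; l < 8; ++l) { /* one SIMD lane at a time */
        float16_t acc = (bias != NULL) ? bias[l] : (float16_t)0;
        for (int i = 0; i < k; ++i) {
          acc += a[h * k + i][l] * b[i * n + w][l];
        }
        c[h * n + w][l] = acc;
      }
    }
  }
}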
+ */ + +#ifndef NNACL_FP16_MATRIX_FP16_H_ +#define NNACL_FP16_MATRIX_FP16_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif +void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n); + +void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c, + const float16_t *bias, int m, int k, int n); +void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n, int in_channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_MATRIX_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.c new file mode 100644 index 00000000..06676190 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.c @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/one_hot_fp16.h" +#include "nnacl_c/errorcode.h" +int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num) { + if (indices == NULL || one_hot_param == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float16_t *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { // output layout: outer_size * depth * inner_size + const int *indices_ptr = indices + i * inner_size; + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = *(indices_ptr++); + if (one_hot_param->support_neg_index_ && index < 0) { + index += depth; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.h new file mode 100644 index 00000000..6d10be8d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/one_hot_fp16.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
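/* Editorial sketch, not part of the patch: single-threaded use of OneHotToFp16
 * above. Only the OneHotStruct fields the implementation reads are set here
 * (outer_size_, inner_size_, depth_, support_neg_index_); zero-initializing
 * the rest of the struct is an assumption. Output layout is
 * outer_size * depth * inner_size. */
static int OneHotToFp16Example(void) {
  OneHotStruct p = {0};
  p.outer_size_ = 2;
  p.inner_size_ = 1;
  p.depth_ = 4;
  p.support_neg_index_ = false;
  int indices[2] = {1, 3};
  float16_t out[2 * 4];
  /* tid = 0, thread_num = 1: row 0 becomes {0,1,0,0}, row 1 becomes {0,0,0,1} */
  return OneHotToFp16(indices, (float16_t)1.0, (float16_t)0.0, out, &p, 0, 1);
}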
+ */ + +#ifndef NNACL_FP16_ONE_HOT_FP16_H_ +#define NNACL_FP16_ONE_HOT_FP16_H_ + +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/one_hot_parameter.h" +#include "nnacl_c/kernel/one_hot.h" + +#ifdef __cplusplus +extern "C" { +#endif +int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ONE_HOT_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c new file mode 100644 index 00000000..164696e6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c @@ -0,0 +1,933 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/pack_fp16.h" +#include + +#ifdef ENABLE_ARM +void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel) { + // nchw to nc8hw8 with 1D F(2,3) + for (int i = 0; i < channel; i++) { + float16_t *src_kernel = (float16_t *)src + i * 9; + float16_t *dst_kernel = (float16_t *)dst + (i / 8) * 96 + i % 8; + for (int y = 0; y < 3; y++) { + float16_t g0 = src_kernel[3 * y]; + float16_t g1 = src_kernel[3 * y + 1]; + float16_t g2 = src_kernel[3 * y + 2]; + + dst_kernel[32 * y] = g0; + dst_kernel[32 * y + 8] = (float16_t)0.5 * (g0 + g1 + g2); + dst_kernel[32 * y + 16] = (float16_t)0.5 * (g0 - g1 + g2); + dst_kernel[32 * y + 24] = g2; + } + } +} +#endif + +void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float16_t)); + } + } +} + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; + ((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNHWCToNC8HW8NotAlignedFp16(const float16_t *src, float16_t *dst, const int batch, const int plane, + const int channel) { + if (channel <= C8NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float16_t)); + return; + } + int tmp = DOWN_DIV(channel, C8NUM); + int c_res = channel - tmp * C8NUM; + int c8_block = tmp * plane * C8NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < 
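/* Editorial sketch, not part of the patch: PackWeightConvDw3x3Fp16 above
 * applies the 1D Winograd F(2,3) weight transform G * g to each kernel row;
 * its g0/g1/g2 expressions are exactly the four rows of G below. Names are
 * illustrative only. */
static const float kWinogradG_F2_3[4][3] = {
  {1.0f, 0.0f, 0.0f},
  {0.5f, 0.5f, 0.5f},
  {0.5f, -0.5f, 0.5f},
  {0.0f, 0.0f, 1.0f},
};
static void WinogradWeightRowRef(const float g[3], float out[4]) {
  for (int i = 0; i < 4; ++i) {
    out[i] = kWinogradG_F2_3[i][0] * g[0] + kWinogradG_F2_3[i][1] * g[1] + kWinogradG_F2_3[i][2] * g[2];
  }
}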
plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C8NUM; + int c = 0; + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t src_data = vld1q_f16(src + src_kernel_offset + c); + vst1q_f16(dst + dst_kernel_offset + c * plane, src_data); + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel, int task_id, + int thread_count) { +#ifdef ENABLE_ARM64 + // Transpose16x8 in arm64 + const int hw_tile = C16NUM; +#else + // Transpose8x8 in others + const int hw_tile = C8NUM; +#endif + int hw_align = plane / hw_tile; + int task_start = 0; + int task_end = plane; + if (thread_count > 0) { + int offset_hw = UP_DIV(hw_align, thread_count) * hw_tile; + task_start = offset_hw * task_id; + int count = plane - task_start; + if (count <= 0) { + return; + } + task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw); + hw_align = task_start + ((task_end - task_start) >= offset_hw ? 
offset_hw : 0); + } else { + hw_align *= hw_tile; + } + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const float16_t *src_batch = (const float16_t *)src + n * batch; + float16_t *dst_batch = (float16_t *)dst + n * batch; + int hw = task_start; + for (; hw < hw_align; hw += hw_tile) { + int c = 0; + for (; c < c8; c += C8NUM) { + const float16_t *src_ptr = src_batch + hw * channel + c; + float16_t *dst_ptr = dst_batch + c * plane + hw; +#ifdef ENABLE_ARM64 + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride); +#elif defined(ENABLE_ARM82_A32) + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride); +#else + for (int tr = 0; tr < hw_tile; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const float16_t *src_ptr = src_batch + hw * channel + c; + float16_t *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < hw_tile; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < task_end; hw++) { + const float16_t *src_ptr = src_batch + hw * channel; + float16_t *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } +} + +void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) { + return PackNHWCToNCHWFp16(src, dst, batch, channel, plane, task_id, thread_count); +} + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic4 = UP_DIV(channel, C4NUM); + int c4_channel = ic4 * C4NUM; + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float16_t *dst_per_plane = (float16_t *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (float16_t *)src + batch_offset + i * channel, channel * sizeof(float16_t)); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic8 = UP_DIV(channel, C8NUM); + int c8_channel = ic8 * C8NUM; + int nhwc8_batch_unit_offset = ic8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float16_t *dst_per_plane = (float16_t *)dst + nhwc8_batch_offset + i * c8_channel; + memcpy(dst_per_plane, (float16_t *)src + batch_offset + i * channel, channel * sizeof(float16_t)); + for (int j = channel; j < c8_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = 
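/* Editorial sketch, not part of the patch: the thread partition used by
 * PackNHWCToNCHWFp16 above. Whole hw tiles are dealt out UP_DIV-style, and
 * only the last thread also takes the unaligned tail of `plane`, so every
 * other thread works on tile-aligned rows. Names are illustrative only. */
static void PlanePartitionRef(int plane, int hw_tile, int thread_count, int task_id,
                              int *task_start, int *task_end) {
  int tiles = plane / hw_tile;
  int per_thread = ((tiles + thread_count - 1) / thread_count) * hw_tile; /* UP_DIV(tiles, n) * tile */
  *task_start = per_thread * task_id;
  int end = *task_start + per_thread;
  *task_end = (task_id + 1 == thread_count) ? plane : (end < plane ? end : plane);
}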
UP_DIV(channel, C4NUM); + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c4 * C4NUM * plane; + for (int i = 0; i < plane; i++) { + memcpy((float16_t *)dst + b * nhwc_batch_unit_offset + i * channel, + (float16_t *)src + batch_offset + i * c4 * C4NUM, channel * sizeof(float16_t)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy((float16_t *)dst, (float16_t *)src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM; + ((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * channel; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp32ToNC8HW8Fp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const 
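/* Editorial sketch, not part of the patch: the per-batch NC4HW4 offset used by
 * the PackNC4HW4To* functions above. Channels are grouped into blocks of
 * C4NUM = 4; each block lays out all `plane` positions contiguously with its
 * 4 channels interleaved per position. Helper name is illustrative only. */
static inline int NC4HW4OffsetRef(int plane, int p, int c) {
  int block = c / 4;
  int rem = c % 4;
  /* add b * plane * UP_DIV(channel, 4) * 4 on top of this for batch b */
  return block * plane * 4 + p * 4 + rem;
}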
float *src = (const float *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp16ToNC8HW8Fp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float16_t *src = (const float16_t *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNCHWFp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float16_t *src = (const float16_t *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_ROUND(channel, C8NUM); + for (int b = 0; b < batch; b++) { + const float16_t *batch_src = src + b * plane * c8; + float16_t *batch_dst = dst + b * plane * channel; + + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C8NUM; + size_t c_mod = c % C8NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset = c_div * plane * C8NUM + p * C8NUM + c_mod; + int dst_offset = c * plane + p; + batch_dst[dst_offset] = batch_src[src_offset]; + } + } + } +} + +void PackNHWCFp32ToNHWC8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + float16_t *dst_batch = dst + b * plane * c8_channel; + const float *src_batch = src + b * plane * channel; + for (int i = 0; i < plane; i++) { + float16_t *dst_plane = dst_batch + i * c8_channel; + const float *src_plane = src_batch + i * channel; + for (int c = 0; c < channel; c++) { + dst_plane[c] = (float16_t)(src_plane[c]); + } + } + } +} + +void PackNHWCFp32ToC8HWN8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + dst[dst_index] = (float16_t)(src[src_index]); + } + } + } + return; +} + +void PackNC8HW8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8 * C8NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C8NUM; + int dst_kernel_offset = 
dst_offset + k * channel; + for (int c = 0; c < c8 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C8NUM; + int dst_c_offset = dst_kernel_offset + c * C8NUM; + vst1q_f16(dst + dst_c_offset, vld1q_f16(src + src_c_offset)); + } + // res part + int res_c = channel - (c8 - 1) * C8NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i; + ((float16_t *)dst + dst_res_c_offset)[0] = ((float16_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNHWCFp16ToC8HWN8Fp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + dst[dst_index] = src[src_index]; + } + } + } + return; +} + +void PackNHWC8Fp16ToNHWCFp32(const float16_t *src, float *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + const float16_t *src_batch = src + b * plane * c8_channel; + float *dst_batch = dst + b * plane * channel; + for (int i = 0; i < plane; i++) { + const float16_t *src_plane = src_batch + i * c8_channel; + float *dst_plane = dst_batch + i * channel; + for (int c = 0; c < channel; c++) { + dst_plane[c] = (float16_t)(src_plane[c]); + } + } + } +} + +void PackNHWC8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + const float16_t *src_batch = src + b * plane * c8_channel; + float16_t *dst_batch = dst + b * plane * channel; + for (int i = 0; i < plane; i++) { + const float16_t *src_plane = src_batch + i * c8_channel; + float16_t *dst_plane = dst_batch + i * channel; + memcpy(dst_plane, src_plane, channel * sizeof(float16_t)); + } + } +} + +#ifdef ENABLE_ARM82_A32 +inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src]\n" + "mov r12, %[dst]\n" + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, d28\n" + + "vst1.16 {q0}, [r12], %[dst_stride]\n" + "vst1.16 {q2}, [r12], %[dst_stride]\n" + "vst1.16 {q4}, [r12], %[dst_stride]\n" + "vst1.16 {q6}, [r12], %[dst_stride]\n" + + "vst1.16 {q8}, [r12], %[dst_stride]\n" + "vst1.16 {q10}, [r12], %[dst_stride]\n" + "vst1.16 {q12}, [r12], %[dst_stride]\n" + "vst1.16 {q14}, [r12], %[dst_stride]\n" + + : + : [dst] "r"(dst), [src] "r"(src), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "r10", "r12", "q0", "q1", 
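/* Editorial sketch, not part of the patch: the destination offset used by the
 * *ToC8HWN8Fp16 packers above. Channel blocks of C8NUM = 8 are outermost, then
 * the spatial position, then the batch index, then the channel remainder.
 * Helper name is illustrative only. */
static inline int C8HWN8OffsetRef(int batch, int plane, int n, int hw, int c) {
  return (c / 8) * batch * plane * 8 + hw * batch * 8 + n * 8 + (c % 8);
}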
"q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} + +inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vld1.16 {q1}, [r10], %[src_stride]\n" + "vld1.16 {q3}, [r10], %[src_stride]\n" + "vld1.16 {q5}, [r10], %[src_stride]\n" + "vld1.16 {q7}, [r10], %[src_stride]\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, d28\n" + + "vtrn.16 d2, d6\n" + "vtrn.16 d3, d7\n" + "vtrn.16 d10, d14\n" + "vtrn.16 d11, d15\n" + + "vtrn.32 d2, d10\n" + "vtrn.32 d6, d14\n" + "vtrn.32 d3, d11\n" + "vtrn.32 d7, d15\n" + + "vst1.16 {q0, d2}, [r12], %[dst_stride]\n" + "vst1.16 {q2, d6}, [r12], %[dst_stride]\n" + "vst1.16 {q4, d10}, [r12], %[dst_stride]\n" + "vst1.16 {q6, d14}, [r12], %[dst_stride]\n" + + "vswp d3, d18\n" + "vswp d7, d22\n" + "vswp d11, d26\n" + "vswp d15, d30\n" + + "vst1.16 {q8, d18}, [r12], %[dst_stride]\n" + "vst1.16 {q10, d22}, [r12], %[dst_stride]\n" + "vst1.16 {q12, d26}, [r12], %[dst_stride]\n" + "vst1.16 {q14, d30}, [r12], %[dst_stride]\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} +#endif + +#ifdef ENABLE_ARM64 +inline void Transpose4x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + dst_stride += dst_stride; + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "add x10, x11, %[dst_stride]\n" + + "zip1 v4.8h, v0.8h, v1.8h\n" + "zip1 v5.8h, v2.8h, v3.8h\n" + + "trn1 v6.4s, v4.4s, v5.4s\n" + "trn2 v7.4s, v4.4s, v5.4s\n" + + "trn1 v24.2d, v6.2d, v7.2d\n" + "trn2 v25.2d, v6.2d, v7.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + + "trn1 v10.4s, v8.4s, v9.4s\n" + "trn2 v11.4s, v8.4s, v9.4s\n" + + "trn1 v26.2d, v10.2d, v11.2d\n" + "trn2 v27.2d, v10.2d, v11.2d\n" + + "st1 {v24.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v25.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v26.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v27.8h}, [x10], %[tow_dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride), + [tow_dst_stride] "r"(2 * dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v24", "v25", "v26", + "v27"); +} + +inline void Transpose8x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov x10, 
%[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + "add x10, x11, %[dst_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v26.2d, v20.2d, v22.2d\n" + "trn1 v25.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v28.2d, v12.2d, v14.2d\n" + "trn2 v30.2d, v12.2d, v14.2d\n" + "trn1 v29.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v24.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v25.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v26.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v27.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v28.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v29.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v30.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v31.8h}, [x10], %[tow_dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride), + [tow_dst_stride] "r"(2 * dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} + +void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { +#ifdef ENABLE_DEBUG + /* scalar fallback; src_stride is the source row stride in bytes */ + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * C12NUM + tr] = src_ptr[tr * (src_stride / sizeof(float16_t)) + tc]; + } + } +#else + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[src_stride]\n" + "ld1 {v9.8h}, [x10], %[src_stride]\n" + "ld1 {v10.8h}, [x10], %[src_stride]\n" + "ld1 {v11.8h}, [x10], %[src_stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + + "trn1 v28.2d, v20.2d, v20.2d\n" + "trn2 v29.2d, v20.2d, v20.2d\n" + "trn1 v30.2d, v21.2d, v21.2d\n" + "trn2 v31.2d, v21.2d, v21.2d\n" + + "add x10, x11, #16\n" + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1
{v28.4h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.4h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.4h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.4h}, [x10], %[dst_stride]\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + + "trn1 v28.2d, v20.2d, v20.2d\n" + "trn2 v29.2d, v20.2d, v20.2d\n" + "trn1 v30.2d, v21.2d, v21.2d\n" + "trn2 v31.2d, v21.2d, v21.2d\n" + + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.4h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.4h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.4h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.4h}, [x10], %[dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#endif +} + +inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[src_stride]\n" + "ld1 {v9.8h}, [x10], %[src_stride]\n" + "ld1 {v10.8h}, [x10], %[src_stride]\n" + "ld1 {v11.8h}, [x10], %[src_stride]\n" + "ld1 {v12.8h}, [x10], %[src_stride]\n" + "ld1 {v13.8h}, [x10], %[src_stride]\n" + "ld1 {v14.8h}, [x10], %[src_stride]\n" + "ld1 {v15.8h}, [x10], %[src_stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "add x10, x11, #16\n" + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], 
%[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [src_stride] "r"(src_stride), [dst_stride] "r"(dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.h new file mode 100644 index 00000000..d2b3d0f7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_PACK_FP16_H_ +#define NNACL_FP16_PACK_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel); + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); + +void PackNHWCToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp32ToNC8HW8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp16ToNC8HW8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCToNC8HW8NotAlignedFp16(const float16_t *src, float16_t *dst, const int batch, const int plane, + const int channel); + +void PackNHWCFp32ToNHWC8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp32ToC8HWN8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp16ToC8HWN8Fp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWC8Fp16ToNHWCFp32(const float16_t *src, float *dst, int batch, int plane, int channel); + +void PackNHWC8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +#ifdef ENABLE_ARM82_A32 +void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); + +void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + +#ifdef ENABLE_ARM64 +void Transpose4x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +void Transpose8x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride); +void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + +#ifdef ENABLE_ARM +void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PACK_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c new file mode 100644 index 00000000..ecee18ad --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/pad_fp16.h" +#include "nnacl_c/common_func.h" + +void PadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *paddings, int tid, int thread_num) { + int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + for (in[3] = 0; in[3] < input_shape[3]; in[3]++) { + out[3] = in[3] + paddings[6]; + for (in[4] = 0; in[4] < input_shape[4]; in[4]++) { + out[4] = in[4] + paddings[8]; + float16_t *dst = output_data + Offset6d(output_shape, out) + paddings[10]; + const float16_t *src = input_data + Offset6d(input_shape, in); + memcpy(dst, src, input_shape[5] * sizeof(float16_t)); + } + } + } + } + } +} + +void MirrorPadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *in_strides, + const int *out_strides, const int *padding, int mirror_offset, int begin, int end) { + for (int i = begin; i < end; ++i) { + output_data[i] = input_data[GetInputFlattenIndex(i, input_shape, in_strides, out_strides, padding, mirror_offset)]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h new file mode 100644 index 00000000..d1666622 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pad_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_PAD_FP16_H_ +#define NNACL_FP16_PAD_FP16_H_ + +#include "nnacl_c/fp32/pad_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *paddings, int tid, int thread_num); +void MirrorPadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *in_strides, + const int *out_strides, const int *padding, int mirror_offset, int begin, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PAD_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c new file mode 100644 index 00000000..fa72c846 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.c @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/pooling_fp16.h" +#include +#include "nnacl_c/errorcode.h" + +int AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + float16_t min = (float16_t)pooling_args->minf; + float16_t max = (float16_t)pooling_args->maxf; + + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int c8 = channel / C8NUM; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + +#ifdef ENABLE_NEON + MS_FLOAT16X8 min_value = MS_MOVQ_F16(min); + MS_FLOAT16X8 max_value = MS_MOVQ_F16(max); +#endif + + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + for (int batch = 0; batch < pooling_args->output_batch_; batch++) { + const float16_t *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float16_t *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? 
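/* Editorial sketch, not part of the patch: the window arithmetic used by
 * AvgPoolingFp16 (continued just below). For output position (oh, ow) the
 * window starts at oh * stride - pad and is clamped to the input bounds; the
 * average divides by the number of in-bounds taps (real_count), not by
 * win_h * win_w. Single-channel fp32 reference, name illustrative only. */
static float AvgPoolWindowRef(const float *in, int in_h, int in_w, int oh, int ow, int win_h, int win_w,
                              int stride_h, int stride_w, int pad_u, int pad_l) {
  int h0 = oh * stride_h - pad_u;
  int w0 = ow * stride_w - pad_l;
  int h_start = h0 < 0 ? -h0 : 0;
  int h_end = win_h < in_h - h0 ? win_h : in_h - h0;
  int w_start = w0 < 0 ? -w0 : 0;
  int w_end = win_w < in_w - w0 ? win_w : in_w - w0;
  float sum = 0.0f;
  int count = 0;
  for (int h = h_start; h < h_end; ++h) {
    for (int w = w_start; w < w_end; ++w) {
      sum += in[(h0 + h) * in_w + (w0 + w)];
      ++count;
    }
  }
  return count > 0 ? sum / (float)count : 0.0f;
}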
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float16_t *src_plane_ptr = src_b_ptr; + float16_t *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + + for (int ci = 0; ci < c8; ci++) { + const float16_t *src_c_ptr = src_plane_ptr + ci * C8NUM; + float16_t *dst_c_ptr = dst_plane_ptr + ci * C8NUM; +#ifdef ENABLE_NEON + MS_FLOAT16X8 tmp_avg = MS_MOVQ_F16(0); +#else + float16_t tmp_avg0 = 0; + float16_t tmp_avg1 = 0; + float16_t tmp_avg2 = 0; + float16_t tmp_avg3 = 0; + float16_t tmp_avg4 = 0; + float16_t tmp_avg5 = 0; + float16_t tmp_avg6 = 0; + float16_t tmp_avg7 = 0; +#endif + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float16_t *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_avg = MS_ADDQ_F16(tmp_avg, MS_LDQ_F16(src_win_ptr)); +#else + tmp_avg0 += src_win_ptr[0]; + tmp_avg1 += src_win_ptr[1]; + tmp_avg2 += src_win_ptr[2]; + tmp_avg3 += src_win_ptr[3]; + tmp_avg4 += src_win_ptr[4]; + tmp_avg5 += src_win_ptr[5]; + tmp_avg6 += src_win_ptr[6]; + tmp_avg7 += src_win_ptr[7]; +#endif + ++real_count; + } + } + if (real_count == 0) { + return NNACL_ERR; + } +#ifdef ENABLE_NEON + tmp_avg = MS_DIVQ_F16(tmp_avg, MS_MOVQ_F16((float16_t)real_count)); + MS_STQ_F16(dst_c_ptr, MS_MINQ_F16(MS_MAXQ_F16(tmp_avg, min_value), max_value)); +#else + dst_c_ptr[0] = MSMIN(MSMAX(tmp_avg0 / (float16_t)real_count, min), max); + dst_c_ptr[1] = MSMIN(MSMAX(tmp_avg1 / (float16_t)real_count, min), max); + dst_c_ptr[2] = MSMIN(MSMAX(tmp_avg2 / (float16_t)real_count, min), max); + dst_c_ptr[3] = MSMIN(MSMAX(tmp_avg3 / (float16_t)real_count, min), max); + dst_c_ptr[4] = MSMIN(MSMAX(tmp_avg4 / (float16_t)real_count, min), max); + dst_c_ptr[5] = MSMIN(MSMAX(tmp_avg5 / (float16_t)real_count, min), max); + dst_c_ptr[6] = MSMIN(MSMAX(tmp_avg6 / (float16_t)real_count, min), max); + dst_c_ptr[7] = MSMIN(MSMAX(tmp_avg7 / (float16_t)real_count, min), max); +#endif + } // c8 loop + int channel_s = c8 * C8NUM; + for (int ci = channel_s; ci < channel; ci++) { + const float16_t *src_c_ptr = src_plane_ptr + ci; + float16_t *dst_c_ptr = dst_plane_ptr + ci; + float16_t tmp_avg = 0; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float16_t *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += src_win_ptr[0]; + ++real_count; + } + } + if (real_count == 0) { + return NNACL_ERR; + } + tmp_avg = tmp_avg / (float16_t)real_count; + tmp_avg = fmax(tmp_avg, min); + tmp_avg = fmin(tmp_avg, max); + dst_c_ptr[0] = tmp_avg; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop + return NNACL_OK; +} + +void MaxPoolingC8Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int 
real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int c8 = channel / C8NUM; +#ifdef ENABLE_NEON + float16x8_t min_value = vdupq_n_f16(min); + float16x8_t max_value = vdupq_n_f16(max); +#endif + for (int j = 0; j < c8; j++) { + int in_channel_offset = in_batch_offset + j * C8NUM; + int out_channel_offset = out_plane_offset + j * C8NUM; +#ifdef ENABLE_NEON + float16x8_t tmp_max = vdupq_n_f16(min); +#else + float16_t tmp_max[8] = {min, min, min, min, min, min, min, min}; +#endif + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_f16(tmp_max, vld1q_f16(input_ptr + in_offset)); +#else + for (int k = 0; k < C8NUM; k++) { + tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmaxq_f16(tmp_max, min_value); + tmp_max = vminq_f16(tmp_max, max_value); + vst1q_f16(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C8NUM; ++l) { + tmp_max[l] = fmax(tmp_max[l], min); + tmp_max[l] = fmin(tmp_max[l], max); + *(output_ptr + out_channel_offset + l) = tmp_max[l]; + } +#endif + } // c8 loop +} + +void MaxPoolingC4Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int c8 = channel / C8NUM; + int c8_res = channel % C8NUM; + int c4 = c8_res / C4NUM; +#ifdef ENABLE_NEON + float16x4_t min_value2 = vdup_n_f16(min); + float16x4_t max_value2 = vdup_n_f16(max); +#endif + int c4_offset = c8 * C8NUM; + for (int j = 0; j < c4; j++) { + int in_channel_offset = in_batch_offset + c4_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + c4_offset + j * C4NUM; +#ifdef ENABLE_NEON + float16x4_t tmp_max = vdup_n_f16(min); +#else + float16_t tmp_max[4] = {min, min, min, min}; +#endif + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmax_f16(tmp_max, vld1_f16(input_ptr + in_offset)); +#else + for (int k = 0; k < C4NUM; k++) { + tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmax_f16(tmp_max, min_value2); + tmp_max = vmin_f16(tmp_max, max_value2); + vst1_f16(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C4NUM; ++l) { + tmp_max[l] = fmax(tmp_max[l], min); + tmp_max[l] = fmin(tmp_max[l], max); + output_ptr[out_channel_offset + l] = tmp_max[l]; + } +#endif + } // c4 loop +} +void MaxPoolingC1Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = 
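/* scalar tail: handles the channels left over after the C8 and C4 vector blocks */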
pooling_args->input_w_; + int c8 = channel / C8NUM; + int c8_res = channel % C8NUM; + int c4 = c8_res / C4NUM; + int channel_s = c8 * C8NUM + c4 * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float16_t tmp_max = -FLT_MAX; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); + } // win_w loop + } // win_h loop + tmp_max = fmax(tmp_max, min); + tmp_max = fmin(tmp_max, max); + output_ptr[out_channel_offset] = tmp_max; + } // channel_res loop +} + +void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + float16_t min = (float16_t)pooling_args->minf; + float16_t max = (float16_t)pooling_args->maxf; + + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int output_batch = pooling_args->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + + // input channel is equal to output channel + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? 
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + MaxPoolingC8Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + MaxPoolingC4Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + MaxPoolingC1Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.h new file mode 100644 index 00000000..6f42e09e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/pooling_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_POOLING_FP16_H_ +#define NNACL_FP16_POOLING_FP16_H_ + +#include <float.h> +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); + +void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_POOLING_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.c new file mode 100644 index 00000000..c2a9d157 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.c @@ -0,0 +1,117 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/power_fp16.h" +#include "nnacl_c/errorcode.h" + +#if defined(ENABLE_NEON) +float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { + int tmp = (int)(*(float16_t *)exponent); + int exp = abs(tmp); + float16x8_t result = vmovq_n_f16(1.0f); + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + if (tmp >= 0) { + return result; + } + return 1 / result; +} +#endif + +float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) { + int tmp = *(float16_t *)exponent; + int exp = abs(tmp); + float16_t result = 1; + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + return tmp >= 0 ? result : 1 / result; +} + +void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift) { + PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; +#if defined(ENABLE_NEON) + PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; +#endif + + if (CheckIntegerFp16(*exponent)) { +#if defined(ENABLE_NEON) + PowerSimdFunFp16_ = OptimizedPowerSimdFp16; +#endif + PowerScalarFunFp16_ = OptimizedPowerScalarFp16; + } else { +#if defined(ENABLE_NEON) + PowerSimdFunFp16_ = StdPowerSimdFp16; +#endif + PowerScalarFunFp16_ = StdPowerScalarFp16; + } + int i = 0; +#ifdef ENABLE_NEON + int len_c8 = DOWN_ROUND(len, C8NUM); + float16x8_t scale_8 = vmovq_n_f16(scale); + float16x8_t shift_8 = vmovq_n_f16(shift); + for (; i < len_c8; i += C8NUM) { + float16x8_t result = PowerSimdFunFp16_(scale_8 * vld1q_f16(input + i) + shift_8, exponent); + vst1q_f16(output + i, result); + } +#endif + for (; i < len; ++i) { + output[i] = PowerScalarFunFp16_(scale * input[i] + shift, exponent); + } +} + +void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift) { + int i = 0; + PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; +#ifdef ENABLE_NEON + int len_c8 = DOWN_ROUND(len, C8NUM); + float16x8_t scale_8 = vmovq_n_f16(scale); + float16x8_t shift_8 = vmovq_n_f16(shift); + for (; i < len_c8; i += C8NUM) { + float16x8_t tmp_8 = scale_8 * vld1q_f16(input + i) + shift_8; + for (int j = 0; j < 8; ++j) { + PowerScalarFunFp16_ = CheckIntegerFp16(exponent[i + j]) ? OptimizedPowerScalarFp16 : StdPowerScalarFp16; + output[i + j] = PowerScalarFunFp16_(tmp_8[j], exponent + i + j); + } + } +#endif + for (; i < len; ++i) { + PowerScalarFunFp16_ = CheckIntegerFp16(exponent[i]) ? OptimizedPowerScalarFp16 : StdPowerScalarFp16; + output[i] = PowerScalarFunFp16_(scale * input[i] + shift, exponent + i); + } +} + +int PowerFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift, + bool broadcast) { + if (input == NULL || exponent == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + PowerFunFp16 PowerFunFp16_ = NULL; + PowerFunFp16_ = broadcast ? 
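/* broadcast: one shared exponent for the whole tensor; otherwise a per-element exponent array */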
PowerBroadCastFp16 : PowerSingleFp16; + PowerFunFp16_(input, exponent, output, len, scale, shift); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.h new file mode 100644 index 00000000..a46139e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/power_fp16.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_POWER_FP16_H_ +#define NNACL_FP16_POWER_FP16_H_ + +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/pow_parameter.h" + +#if defined(ENABLE_NEON) +typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent); +#endif +typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent); +typedef void (*PowerFunFp16)(const float16_t *, const float16_t *, float16_t *, int, float, float); + +#ifdef __cplusplus +extern "C" { +#endif +static inline bool CheckIntegerFp16(float16_t f) { return floorf(f) == f; } + +static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) { + return powf(x, *(float16_t *)exponent); +} + +#if defined(ENABLE_NEON) +static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) { + float16x8_t result; + result[0] = powf(x[0], *(float16_t *)exponent); + result[1] = powf(x[1], *(float16_t *)exponent); + result[2] = powf(x[2], *(float16_t *)exponent); + result[3] = powf(x[3], *(float16_t *)exponent); + result[4] = powf(x[4], *(float16_t *)exponent); + result[5] = powf(x[5], *(float16_t *)exponent); + result[6] = powf(x[6], *(float16_t *)exponent); + result[7] = powf(x[7], *(float16_t *)exponent); + return result; +} +#endif +int PowerFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift, + bool broadcast); +void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift); +void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_POWER_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.c new file mode 100644 index 00000000..0a062adc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.c @@ -0,0 +1,146 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/prelu_fp16.h" + +#ifdef ENABLE_ARM64 +static inline void PReluFp164x32(const float16_t *in, float16_t *out, const float16_t *cur_slope, size_t step) { + asm volatile( + "mov x10, %[in]\n" + "mov x11, %[out]\n" + "mov x12, %[cur_slope]\n" + "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12]\n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n" + "fmul v16.8h, v0.8h, v4.8h\n" + "fmul v17.8h, v1.8h, v5.8h\n" + "fmul v18.8h, v2.8h, v6.8h\n" + "fmul v19.8h, v3.8h, v7.8h\n" + "fcmgt v20.8h, v0.8h, #0\n" + "fcmgt v21.8h, v1.8h, #0\n" + "fcmgt v22.8h, v2.8h, #0\n" + "fcmgt v23.8h, v3.8h, #0\n" + "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], %[step]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.8h, v24.8h, v4.8h\n" + "fmul v9.8h, v25.8h, v5.8h\n" + "fmul v10.8h, v26.8h, v6.8h\n" + "fmul v11.8h, v27.8h, v7.8h\n" + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n" + "fcmgt v12.8h, v24.8h, #0\n" + "fcmgt v13.8h, v25.8h, #0\n" + "fcmgt v14.8h, v26.8h, #0\n" + "fcmgt v15.8h, v27.8h, #0\n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "fmul v16.8h, v0.8h, v4.8h\n" + "fmul v17.8h, v1.8h, v5.8h\n" + "fmul v18.8h, v2.8h, v6.8h\n" + "fmul v19.8h, v3.8h, v7.8h\n" + "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11], %[step]\n" + "fcmgt v20.8h, v0.8h, #0\n" + "fcmgt v21.8h, v1.8h, #0\n" + "fcmgt v22.8h, v2.8h, #0\n" + "fcmgt v23.8h, v3.8h, #0\n" + "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.8h, v24.8h, v4.8h\n" + "fmul v9.8h, v25.8h, v5.8h\n" + "fmul v10.8h, v26.8h, v6.8h\n" + "fmul v11.8h, v27.8h, v7.8h\n" + "fcmgt v12.8h, v24.8h, #0\n" + "fcmgt v13.8h, v25.8h, #0\n" + "fcmgt v14.8h, v26.8h, #0\n" + "fcmgt v15.8h, v27.8h, #0\n" + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11]\n" + : + : [in] "r"(in), [out] "r"(out), [cur_slope] "r"(cur_slope), [step] "r"(step) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27"); +} +#endif + +void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel) { + int i = start; +#ifdef ENABLE_ARM64 + for (; i <= end - C4NUM; i += C4NUM) { + const float16_t *cur_in = input + i * channel; + float16_t *cur_out = output + i * channel; + int j = 0; + for (; j <= channel - C32NUM; j += C32NUM) { + const float16_t *in = cur_in + j; + float16_t *out = cur_out + j; + const float16_t *cur_slope = slope + j; + size_t step = channel *
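/* step is the row stride in bytes; the asm kernel walks four rows of 32 channels per call */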
sizeof(float16_t); + PReluFp164x32(in, out, cur_slope, step); + } + for (; j < channel; j++) { + cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]); + cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j]; + cur_out[j + 2 * channel] = + (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]); + cur_out[j + 3 * channel] = + (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]); + } + } +#endif + for (; i < end; i++) { + const float16_t *cur_in = input + i * channel; + float16_t *cur_out = output + i * channel; + int j = 0; +#ifdef ENABLE_NEON + for (; j <= channel - C8NUM; j += C8NUM) { + float16x8_t in = vld1q_f16(cur_in + j); + float16x8_t s = vld1q_f16(slope + j); + float16x8_t mul = vmulq_f16(in, s); + uint16x8_t mask = vcleq_f16(in, vmovq_n_f16(0.0f)); + vst1q_f16(cur_out + j, vbslq_f16(mask, mul, in)); + } +#endif + for (; j < channel; j++) { + cur_out[j] = cur_in[j] > 0 ? cur_in[j] : cur_in[j] * slope[j]; + } + } +} + +void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end) { + int i = start; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t slope_data = vdupq_n_f16(slope); + for (; i <= end - C8NUM; i += C8NUM) { + float16x8_t src_tmp = vld1q_f16(input + i); + float16x8_t mul_tmp = vmulq_f16(src_tmp, slope_data); + uint16x8_t mask = vcleq_f16(src_tmp, zero_data); + vst1q_f16(output + i, vbslq_f16(mask, mul_tmp, src_tmp)); + } +#endif + for (; i < end; i++) { + output[i] = input[i] > 0 ? input[i] : input[i] * slope; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.h new file mode 100644 index 00000000..01a12799 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/prelu_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_PRELU_FP16_H_ +#define NNACL_FP16_PRELU_FP16_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel); + +void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PRELU_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c new file mode 100644 index 00000000..cf5ab3d6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.c @@ -0,0 +1,290 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <math.h> +#include "nnacl_c/fp16/quant_dtype_cast_fp16.h" +#include "nnacl_c/errorcode.h" + +#ifdef ENABLE_ARM64 +void Int8ToFp16_arm64(const int8_t *quant_values, float16_t *dst, float scale, int32_t zp, int size) { + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "dup v20.4s, %w[zp32]\n" + "dup v21.4s, %w[scale]\n" + + "cmp w8, #16\n" + "blt 1f\n" + + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v7.16b}, [%[quant_values]], #16\n" + + "sxtl v8.8h, v7.8b\n" + "sxtl2 v9.8h, v7.16b\n" + + "sxtl v0.4s, v8.4h\n" + "sxtl2 v1.4s, v8.8h\n" + "sxtl v2.4s, v9.4h\n" + "sxtl2 v3.4s, v9.8h\n" + "sub v0.4s, v0.4s, v20.4s\n" + "sub v1.4s, v1.4s, v20.4s\n" + "sub v2.4s, v2.4s, v20.4s\n" + "sub v3.4s, v3.4s, v20.4s\n" + "scvtf v4.4s, v0.4s\n" + "scvtf v5.4s, v1.4s\n" + "scvtf v6.4s, v2.4s\n" + "scvtf v7.4s, v3.4s\n" + + "fmul v0.4s, v4.4s, v21.4s\n" + "fmul v1.4s, v5.4s, v21.4s\n" + "fmul v2.4s, v6.4s, v21.4s\n" + "fmul v3.4s, v7.4s, v21.4s\n" + + "fcvtn v4.4h, v0.4s\n" + "fcvtn2 v4.8h, v1.4s\n" + "fcvtn v5.4h, v2.4s\n" + "fcvtn2 v5.8h, v3.4s\n" + + "st1 {v4.8h, v5.8h}, [%[dst]], #32\n" + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "ldrsb w9, [%[quant_values]], #1\n" + + "subs w8, w8, #1\n" + "sub w9, w9, %w[zp32]\n" + "scvtf s9, w9\n" + + "fmul s9, s9, s21\n" + "fcvtn v4.4h, v9.4s\n" + "str h4, [%[dst]], #2\n" + "bne 1b\n" + + "2:\n" + + : + : [quant_values] "r"(quant_values), [dst] "r"(dst), [scale] "r"(scale), [zp32] "r"(zp), [size] "r"(size) + : "w8", "w9", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v20", "v21"); +} +#endif + +int DoDequantizeInt8ToFp16(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Int8ToFp16_arm64(quant_values, real_values, scale, zp, size); +#else + for (int i = 0; i < size; ++i) { + real_values[i] = (quant_values[i] - zp) * scale; + } +#endif + return NNACL_OK; +} + +#ifdef ENABLE_ARM64 +void Fp16ToInt8_arm64(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { + const float one = 1.0f; + const float ivs = one / scale; + const int32_t min_value = -128; + const int32_t max_value = 127; + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, wzr\n" + "beq 3f\n" + + "dup v28.4s, %w[ivs]\n" + "dup v29.4s, %w[min_value]\n" + "dup v30.4s, %w[max_value]\n" + + "cmp w8, #32\n" + "blt 2f\n" + "1:\n" // loop 32 + "subs w8, w8, #32\n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[real_values]], #64\n" + "fcvtl v8.4s, v0.4h\n" + "fcvtl2 v9.4s, v0.8h\n" + "fcvtl v10.4s, v1.4h\n" + "fcvtl2 v11.4s, v1.8h\n" + "fcvtl v12.4s, v2.4h\n" + "fcvtl2 v13.4s, v2.8h\n" + "fcvtl v14.4s, v3.4h\n" + "fcvtl2 v15.4s, v3.8h\n" + + "dup v16.4s, %w[zp]\n" + "dup v17.4s, %w[zp]\n" + "dup v18.4s, %w[zp]\n" + "dup v19.4s, %w[zp]\n" + "dup v20.4s, %w[zp]\n" + "dup v21.4s, %w[zp]\n" + "dup v22.4s, %w[zp]\n" + "dup v23.4s, %w[zp]\n" + "scvtf v16.4s, v16.4s\n" + "scvtf v17.4s, v17.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v20.4s, v20.4s\n" + "scvtf
v21.4s, v21.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v23.4s, v23.4s\n" + + "fmla v16.4s, v8.4s, v28.4s\n" + "fmla v17.4s, v9.4s, v28.4s\n" + "fmla v18.4s, v10.4s, v28.4s\n" + "fmla v19.4s, v11.4s, v28.4s\n" + "fmla v20.4s, v12.4s, v28.4s\n" + "fmla v21.4s, v13.4s, v28.4s\n" + "fmla v22.4s, v14.4s, v28.4s\n" + "fmla v23.4s, v15.4s, v28.4s\n" + + "fcvtas v8.4s, v16.4s\n" + "fcvtas v9.4s, v17.4s\n" + "fcvtas v10.4s, v18.4s\n" + "fcvtas v11.4s, v19.4s\n" + "fcvtas v12.4s, v20.4s\n" + "fcvtas v13.4s, v21.4s\n" + "fcvtas v14.4s, v22.4s\n" + "fcvtas v15.4s, v23.4s\n" + + "smax v8.4s, v8.4s, v29.4s\n" + "smax v9.4s, v9.4s, v29.4s\n" + "smax v10.4s, v10.4s, v29.4s\n" + "smax v11.4s, v11.4s, v29.4s\n" + "smax v12.4s, v12.4s, v29.4s\n" + "smax v13.4s, v13.4s, v29.4s\n" + "smax v14.4s, v14.4s, v29.4s\n" + "smax v15.4s, v15.4s, v29.4s\n" + "smin v8.4s, v8.4s, v30.4s\n" + "smin v9.4s, v9.4s, v30.4s\n" + "smin v10.4s, v10.4s, v30.4s\n" + "smin v11.4s, v11.4s, v30.4s\n" + "smin v12.4s, v12.4s, v30.4s\n" + "smin v13.4s, v13.4s, v30.4s\n" + "smin v14.4s, v14.4s, v30.4s\n" + "smin v15.4s, v15.4s, v30.4s\n" + + "sqxtn v16.4h, v8.4s\n" + "sqxtn2 v16.8h, v9.4s\n" + "sqxtn v17.4h, v10.4s\n" + "sqxtn2 v17.8h, v11.4s\n" + "sqxtn v18.4h, v12.4s\n" + "sqxtn2 v18.8h, v13.4s\n" + "sqxtn v19.4h, v14.4s\n" + "sqxtn2 v19.8h, v15.4s\n" + "sqxtn v20.8b, v16.8h\n" + "sqxtn2 v20.16b, v17.8h\n" + "sqxtn v21.8b, v18.8h\n" + "sqxtn2 v21.16b, v19.8h\n" + + "st1 {v20.16b, v21.16b}, [%[quant_values]], #32\n" + + "beq 3f\n" + "cmp w8, #32\n" + "bge 1b\n" + + "2:\n" // 1 by 1 + "scvtf s10, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr h0, [%[real_values]], #2\n" + "fcvt s0, h0\n" + "fmul s0, s0, s28\n" + "fadd s0, s0, s10\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v29.4s\n" + "smin v0.4s, v0.4s, v30.4s\n" + "sqxtn v0.4h, v0.4s\n" + "sqxtn v0.8b, v0.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + "bne 2b\n" + + "3:\n" + : + : [size] "r"(size), [ivs] "r"(ivs), [real_values] "r"(real_values), [quant_values] "r"(quant_values), [zp] "r"(zp), + [min_value] "r"(min_value), [max_value] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v28", "v29", "v30"); +} +#endif + +int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp16ToInt8_arm64(real_values, quant_values, scale, zp, size); +#else + const int8_t min_value = -128; + const int8_t max_value = 127; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + continue; + } + if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + continue; + } + float temp = round((float)real_values[i] / scale + zp); + if (temp > max_value) { + quant_values[i] = max_value; + } else if (temp < min_value) { + quant_values[i] = min_value; + } else { + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +int DoDequantizeUInt8ToFp16(const uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) { + uint8_t zp_ = (uint8_t)zp; + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + real_values[i] = (quant_values[i] - zp_) * scale; + } + return NNACL_OK; +} + +int DoQuantizeFp16ToUInt8(const float16_t *real_values, uint8_t *quant_values, float 
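/* q = round(r / scale + zp), clamped to [0, 255]; infinities saturate */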
scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + if (isinf((float)real_values[i])) { + quant_values[i] = 255; + continue; + } + float temp = round((float)real_values[i] / scale + zp); + if (temp > 255.0f) { + quant_values[i] = 255; + } else if (temp < 0.0f) { + quant_values[i] = 0; + } else { + quant_values[i] = (uint8_t)temp; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h new file mode 100644 index 00000000..b45b16a8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/quant_dtype_cast_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_QUANTDTYPECAST_FP16_H_ +#define NNACL_FP16_QUANTDTYPECAST_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoDequantizeInt8ToFp16(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size); + +int DoDequantizeUInt8ToFp16(const uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp16ToUInt8(const float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_QUANTDTYPECAST_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.c new file mode 100644 index 00000000..ce1e26d0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.c @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/ragged_range_fp16.h" + +void RaggedRangeFp16(const float16_t *starts, const float16_t *limits, const float16_t *deltas, int *splits, + float16_t *value, const RaggedRangeStruct *param) { + splits[0] = 0; + for (int i = 0; i < param->rows_; i++) { + float16_t start = param->starts_is_scalar_ ? starts[0] : starts[i]; + float16_t limit = param->limits_is_scalar_ ?
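/* scalar start/limit/delta inputs broadcast across every row */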
limits[0] : limits[i]; + float16_t delta = param->deltas_is_scalar_ ? deltas[0] : deltas[i]; + int len = NNACL_MAX((int)ceil((float16_t)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.h new file mode 100644 index 00000000..91088cfe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/ragged_range_fp16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_RAGGED_RANGE_FP16_H_ +#define NNACL_FP16_RAGGED_RANGE_FP16_H_ + +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/ragged_range.h" + +void RaggedRangeFp16(const float16_t *starts, const float16_t *limits, const float16_t *deltas, int *splits, + float16_t *value, const RaggedRangeStruct *param); + +#endif // NNACL_FP16_RAGGED_RANGE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/range_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/range_fp16.h new file mode 100644 index 00000000..9eb9d833 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/range_fp16.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_RANGE_FP16_H_ +#define NNACL_FP16_RANGE_FP16_H_ + +#include "nnacl_c/op_base.h" + +static inline void RangeFp16(float16_t *output_ptr, float16_t start, float16_t delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +#endif // NNACL_FP16_RANGE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.c new file mode 100644 index 00000000..5c9f4ac4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.c @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <math.h> +#include <float.h> +#include "nnacl_c/fp16/reduce_fp16.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +int ReduceMeanFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (axis_size == 0) { + return NNACL_ERR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float tmp = 0.0; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size]; + } + *inner_dst = (float16_t)(tmp / axis_size); + } + } + return NNACL_OK; +} + +int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float tmp = -FLT_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceMinFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float16_t tmp = 65504;  // fp16 max value + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ?
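/* running minimum over the reduced axis */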
tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceProdFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float16_t tmp = 1.0f; + for (i = 0; i < axis_size; i++) { + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + int stride = UP_DIV(outer_size, thread_num); + int start = stride * tid; + int end = MSMIN(outer_size, start + stride); + int num = end - start; +#ifdef ENABLE_NEON + int block_c8 = inner_size - inner_size % C8NUM; +#endif + + int src_stride = axis_size * inner_size; + src_data += start * src_stride; + dst_data += start * inner_size; + + for (int i = 0; i < num; i++, src_data += src_stride, dst_data += inner_size) { + int j = 0; +#ifdef ENABLE_NEON + for (; j < block_c8; j += C8NUM) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float16x8_t tmp = {0, 0, 0, 0, 0, 0, 0, 0}; + for (int k = 0; k < axis_size; k++) { + tmp = vaddq_f16(tmp, vld1q_f16(inner_src + k * inner_size)); + } + vst1q_f16(inner_dst, tmp); + } +#endif + for (; j < inner_size; j++) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float tmp = 0.0f; + for (int k = 0; k < axis_size; k++) { + tmp += inner_src[k * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceL2NormFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + int stride = UP_DIV(outer_size, thread_num); + int start = stride * tid; + int end = MSMIN(outer_size, start + stride); + int num = end - start; +#ifdef ENABLE_NEON + int block_c8 = inner_size - inner_size % C8NUM; +#endif + + int src_stride = axis_size * inner_size; + src_data += start * src_stride; + dst_data += start * inner_size; + + for (int i = 0; i < num; i++, src_data += src_stride, dst_data += inner_size) { + int j = 0; +#ifdef ENABLE_NEON + for (; j < block_c8; j += C8NUM) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float16x8_t tmp = {0, 0, 0, 0, 0, 0, 0, 0}; + for (int k = 0; k < axis_size; k++) { + float16x8_t src = vld1q_f16(inner_src + k * inner_size); + tmp = MS_FMAQ_F16(tmp, src, src); + } + vst1q_f16(inner_dst, MS_SQRTFX8_F16(tmp)); + } +#endif + for (; j < inner_size; j++) { + const float16_t *inner_src = src_data + j; + float tmp = 0.0f; + for (int k = 0; k < axis_size; k++) { + tmp += inner_src[k * inner_size] * inner_src[k * inner_size]; + } + dst_data[j] = sqrtf(tmp); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.h new file mode 100644 index 00000000..638f76de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/reduce_fp16.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_REDUCE_FP16_H_ +#define NNACL_FP16_REDUCE_FP16_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/reduce_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ReduceMeanFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceMinFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceProdFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceL2NormFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_REDUCE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.c new file mode 100644 index 00000000..a5b0b318 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.c @@ -0,0 +1,380 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <math.h> +#include "nnacl_c/fp16/resize_fp16.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +void CalculateCoordinateFp16(float16_t out, int in, int *bottom, int *top, float16_t *bottom_weight) { + *bottom = (int)(floorf(out)); + *bottom = *bottom >= 0 ? *bottom : 0;  // extrapolate may generate neg value + *top = *bottom + 1 < in ?
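/* clamp the upper neighbour so extrapolation stays inside the input */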
(*bottom + 1) : (in - 1); + float16_t top_weight = (float16_t)out - (float16_t)(*bottom); + *bottom_weight = 1.0f - top_weight; +} + +static void BicubicBaseFuncFp16(float16_t a, float16_t x, float16_t *weight) { + float16_t abs_x = fabsf(x); + if (abs_x >= 0 && abs_x <= 1) { + *weight = ((a + 2) * abs_x - (a + 3)) * abs_x * abs_x + 1; + } else if (abs_x > 1 && abs_x <= 2) { + *weight = a * abs_x * abs_x * abs_x - 5 * a * abs_x * abs_x + 8 * a * abs_x - 4 * a; + } else { + *weight = 0; + } +} + +// a is a coefficient +// W(x) = { (a + 2) * |x| * |x| * |x| - (a + 3) * |x| * |x| + 1, for |x| <= 1 +// { a * |x| * |x| * |x| - 5 * a * |x| * |x| + 8 * a *|x| - 4 * a, for 1 < |x| < 2 +// { 0, otherwise +// the value of 'a' depends on if is half_pixel_center(the scheme is the same as tf). +// If is half pixel mode, a equals to -0.5, otherwise -0.75. +void CalculateWeightForBicubicFp16(float16_t out, int in, int *index, float16_t *weights, float16_t a) { + int floor_index = (int)(floorf(out)); + index[0] = (floor_index - 1) < 0 ? 0 : (floor_index - 1); + index[1] = floor_index; + index[2] = (floor_index + 1) < in ? (floor_index + 1) : (in - 1); + index[3] = (floor_index + 2) < in ? (floor_index + 2) : (in - 1); + + // get positive value + float16_t distance[4] = {-1, 0, 1, 2}; + float16_t tmp_dis = out - (float16_t)floor_index; + distance[0] -= tmp_dis; + distance[1] -= tmp_dis; + distance[2] -= tmp_dis; + distance[3] -= tmp_dis; + + for (int i = 0; i < 4; ++i) { + BicubicBaseFuncFp16(a, distance[i], &weights[i]); + } +} + +int PrepareResizeBilinearFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float16_t *y_bottom_weights, + float16_t *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float16_t actual_y = calculate(h, in_h, new_height); + CalculateCoordinateFp16(actual_y, in_h, y_bottoms + h, y_tops + h, y_bottom_weights + h); + } + for (int w = 0; w < new_width; w++) { + float16_t actual_x = calculate(w, in_w, new_width); + CalculateCoordinateFp16(actual_x, in_w, x_lefts + w, x_rights + w, x_left_weights + w); + } + return NNACL_OK; +} + +int PrepareResizeBicubicFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_tops, int *x_lefts, float16_t *y_weights, float16_t *x_weights, + float16_t cubic_coeff) { + if (input_shape == NULL || output_shape == NULL || y_tops == NULL || x_lefts == NULL || y_weights == NULL || + x_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float16_t actual_y = calculate(h, in_h, new_height); + CalculateWeightForBicubicFp16(actual_y, in_h, y_tops + 4 * h, y_weights + 4 * h, cubic_coeff); + } + for (int w = 0; w < new_width; w++) { + float16_t actual_x = calculate(w, in_w, new_width); + CalculateWeightForBicubicFp16(actual_x, in_w, x_lefts + 4 * w, x_weights + 4 * w, cubic_coeff); + } + return NNACL_OK; +} + +int InterpRowFp16(const float16_t *src_line, float16_t *linear_output, int new_width, 
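/* horizontal pass: blends the left and right source columns of one row */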
const float16_t *x_left_weights, + const int *x_lefts, const int *x_rights, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t left_w_8 = vdupq_n_f16(x_left_weights[w]); + float16x8_t right_w_8 = vdupq_n_f16(1.0f - x_left_weights[w]); + for (; c <= in_c - C8NUM; c += C8NUM) { + float16x8_t left = vld1q_f16(src_line + x_lefts[w] * in_c + c); + float16x8_t right = vld1q_f16(src_line + x_rights[w] * in_c + c); + float16x8_t interp_value = vaddq_f16(vmulq_f16(left, left_w_8), vmulq_f16(right, right_w_8)); + vst1q_f16(linear_output + w * in_c + c, interp_value); + } +#endif + int left_w_offset = x_lefts[w] * in_c; + int right_w_offset = x_rights[w] * in_c; + for (; c < in_c; c++) { + float16_t left = src_line[left_w_offset + c]; + float16_t right = src_line[right_w_offset + c]; + linear_output[w * in_c + c] = left * x_left_weights[w] + right * (1.0f - x_left_weights[w]); + } + } + return 0; +} + +int InterpColFp16(const float16_t *bottom_line, const float16_t *top_line, float16_t *output, int new_width, + float16_t y_bottom_weight, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t bottom_w_8 = vdupq_n_f16(y_bottom_weight); + float16x8_t top_w_8 = vdupq_n_f16(1.0f - y_bottom_weight); + for (; c <= in_c - C8NUM; c += C8NUM) { + float16x8_t bottom = vld1q_f16(bottom_line + w * in_c + c); + float16x8_t top = vld1q_f16(top_line + w * in_c + c); + float16x8_t interp_value = vaddq_f16(vmulq_f16(bottom, bottom_w_8), vmulq_f16(top, top_w_8)); + vst1q_f16(output + w * in_c + c, interp_value); + } +#endif + for (; c < in_c; c++) { + float16_t bottom = bottom_line[w * in_c + c]; + float16_t top = top_line[w * in_c + c]; + output[w * in_c + c] = bottom * y_bottom_weight + top * (1.0f - y_bottom_weight); + } + } + return 0; +} + +void BilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *y_bottom, const int *y_top, const int *x_left, const int *x_right, + const float16_t *y_bottom_weight, const float16_t *x_left_weight, float16_t *line0, float16_t *line1, + const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + bool cache_line_used[2] = {false, false}; + int cache_line_num[2] = {-1, -1}; + float16_t *const cache_line_ptr[2] = {line0, line1}; + float16_t *current_line_ptr[2] = {line0, line1}; + int current_line_num[2] = {-1, -1}; + + for (int h = h_begin; h < h_end; h++) { + current_line_num[0] = y_bottom[h]; + current_line_num[1] = y_top[h]; + + for (int i = 0; i < 2; i++) { + cache_line_used[i] = false; + } + // search if we cached + for (int j = 0; j < 2; j++) { + bool find = false; + for (int k = 0; k < 2; k++) { + if (current_line_num[j] == cache_line_num[k]) { + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + find = true; + break; + } + } + + if (!find) { + const float16_t *line = input_data + current_line_num[j] * in_w * in_c; + for (int k = 0; k < 2; k++) { + if (!cache_line_used[k]) { + cache_line_num[k] = current_line_num[j]; + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + InterpRowFp16(line, current_line_ptr[j], new_width, x_left_weight, x_left, x_right, in_c); + break; + } + } + } + } + // do col interp + InterpColFp16(current_line_ptr[0], current_line_ptr[1], output_data + h * h_stride, new_width, y_bottom_weight[h], + in_c); + } +} + +int 
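/* batch-level driver: interpolates output rows [h_begin, h_end) for every batch */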
ResizeBilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_bottoms, const int *y_tops, const int *x_lefts, + const int *x_rights, const float16_t *y_bottom_weights, const float16_t *x_left_weights, + float16_t *line0, float16_t *line1, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL || + y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_b = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int b = 0; b < in_b; b++) { + const float16_t *input = input_data + b * in_h * in_w * in_c; + float16_t *output = output_data + b * new_height * new_width * in_c; + BilinearFp16(input, output, input_shape, output_shape, y_bottoms, y_tops, x_lefts, x_rights, y_bottom_weights, + x_left_weights, line0, line1, h_begin, h_end); + } + return NNACL_OK; +} + +void BicubicInterpRowFp16(const float16_t *src, float16_t *dst, const float16_t *weights, const int *lefts, int width, + int channel) { + for (int w = 0; w < width; w++) { + const float16_t *weight = weights + 4 * w; + float16_t *dst_w = dst + w * channel; + const float16_t *src0_w = src + lefts[4 * w] * channel; + const float16_t *src1_w = src + lefts[4 * w + 1] * channel; + const float16_t *src2_w = src + lefts[4 * w + 2] * channel; + const float16_t *src3_w = src + lefts[4 * w + 3] * channel; + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t weight0_vec_8 = vdupq_n_f16(weight[0]); + float16x8_t weight1_vec_8 = vdupq_n_f16(weight[1]); + float16x8_t weight2_vec_8 = vdupq_n_f16(weight[2]); + float16x8_t weight3_vec_8 = vdupq_n_f16(weight[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t src0_vec = vld1q_f16(src0_w + c); + float16x8_t src1_vec = vld1q_f16(src1_w + c); + float16x8_t src2_vec = vld1q_f16(src2_w + c); + float16x8_t src3_vec = vld1q_f16(src3_w + c); + float16x8_t dst0 = vmulq_f16(src0_vec, weight0_vec_8); + float16x8_t dst1 = vmulq_f16(src1_vec, weight1_vec_8); + float16x8_t dst2 = vmulq_f16(src2_vec, weight2_vec_8); + float16x8_t dst3 = vmulq_f16(src3_vec, weight3_vec_8); + float16x8_t interp_value = vaddq_f16(dst3, vaddq_f16(dst2, vaddq_f16(dst1, dst0))); + vst1q_f16(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weight[0] + src1_w[c] * weight[1] + src2_w[c] * weight[2] + src3_w[c] * weight[3]; + } + } +} + +void BicubicInterpColFp16(const float16_t *src, float16_t *dst, const float16_t *weights, int width, int channel) { + const float16_t *src0 = src; + const float16_t *src1 = src + width * channel; + const float16_t *src2 = src + 2 * width * channel; + const float16_t *src3 = src + 3 * width * channel; + for (int w = 0; w < width; w++) { + float16_t *dst_w = dst + w * channel; + const float16_t *src0_w = src0 + w * channel; + const float16_t *src1_w = src1 + w * channel; + const float16_t *src2_w = src2 + w * channel; + const float16_t *src3_w = src3 + w * channel; + int c = 0; +#ifdef ENABLE_NEON + float16x8_t weight0_vec_8 = vdupq_n_f16(weights[0]); + float16x8_t weight1_vec_8 = vdupq_n_f16(weights[1]); + float16x8_t weight2_vec_8 = vdupq_n_f16(weights[2]); + float16x8_t weight3_vec_8 = vdupq_n_f16(weights[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t 
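/* eight fp16 channels per NEON register for each of the four bicubic taps */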
src0_vec = vld1q_f16(src0_w + c); + float16x8_t src1_vec = vld1q_f16(src1_w + c); + float16x8_t src2_vec = vld1q_f16(src2_w + c); + float16x8_t src3_vec = vld1q_f16(src3_w + c); + float16x8_t dst1 = vmulq_f16(src0_vec, weight0_vec_8); + float16x8_t dst2 = vmulq_f16(src1_vec, weight1_vec_8); + float16x8_t dst3 = vmulq_f16(src2_vec, weight2_vec_8); + float16x8_t dst4 = vmulq_f16(src3_vec, weight3_vec_8); + float16x8_t interp_value = vaddq_f16(dst4, vaddq_f16(dst3, vaddq_f16(dst1, dst2))); + vst1q_f16(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weights[0] + src1_w[c] * weights[1] + src2_w[c] * weights[2] + src3_w[c] * weights[3]; + } + } +} + +void BicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *y_tops, const int *x_lefts, const float16_t *y_weights, const float16_t *x_weights, + float16_t *line_buffer, const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + for (int h = h_begin; h < h_end; h++) { + for (int i = 0; i < 4; ++i) { + BicubicInterpRowFp16(input_data + y_tops[4 * h + i] * in_w * in_c, line_buffer + i * h_stride, x_weights, x_lefts, + new_width, in_c); + } + BicubicInterpColFp16(line_buffer, output_data + h * h_stride, y_weights + 4 * h, new_width, in_c); + } +} + +int ResizeBicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_tops, const int *x_lefts, const float16_t *y_weights, + const float16_t *x_weights, float16_t *line_buffer, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_tops == NULL || + x_lefts == NULL || y_weights == NULL || x_weights == NULL) { + return NNACL_NULL_PTR; + } + int input_cube_per_batch = input_shape[1] * input_shape[2] * input_shape[3]; + int output_cube_per_batch = output_shape[1] * output_shape[2] * input_shape[3]; + for (int b = 0; b < input_shape[0]; b++) { + const float16_t *input = input_data + b * input_cube_per_batch; + float16_t *output = output_data + b * output_cube_per_batch; + BicubicFp16(input, output, input_shape, output_shape, y_tops, x_lefts, y_weights, x_weights, line_buffer, h_begin, + h_end); + } + return NNACL_OK; +} + +int ResizeNearestNeighborFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num) { + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int c = input_shape[3]; + bool align_corners = coordinate_transform_mode == 1; + for (int batch = 0; batch < output_shape[0]; batch++) { + for (int y = tid; y < output_shape[1]; y += thread_num) { + float16_t actual_y = calculate(y, input_shape[1], output_shape[1]); + int input_y; + if (align_corners) { + input_y = (int)(roundf(actual_y)); + } else { + input_y = (int)(floorf(actual_y)); + } + for (int x = 0; x < output_shape[2]; x++) { + float16_t actual_x = calculate(x, input_shape[2], output_shape[2]); + int input_x; + if (align_corners) { + input_x = (int)(roundf(actual_x)); + } else { + input_x = (int)(floorf(actual_x)); + } + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float16_t)); 
+ } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.h new file mode 100644 index 00000000..a0fecbe1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/resize_fp16.h @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_RESIZE_FP16_H_ +#define NNACL_FP16_RESIZE_FP16_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/resize_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" +#include "nnacl_c/fp32/resize_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PrepareResizeBilinearFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float16_t *y_bottom_weights, + float16_t *x_left_weights); + +int PrepareResizeBicubicFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_tops, int *x_lefts, float16_t *y_weights, float16_t *x_weights, + float16_t cubic_coeff); + +int ResizeBilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_bottoms, const int *y_tops, const int *x_lefts, + const int *x_rights, const float16_t *y_bottom_weights, const float16_t *x_left_weights, + float16_t *line0, float16_t *line1, const int h_begin, const int h_end); + +int ResizeBicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_tops, const int *x_lefts, const float16_t *y_weights, + const float16_t *x_weights, float16_t *line_buffer, const int h_begin, const int h_end); + +int ResizeNearestNeighborFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_RESIZE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.c new file mode 100644 index 00000000..579aa6d5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.c @@ -0,0 +1,226 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
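A minimal driver sketch for the bilinear resize API declared above (illustrative only: buffer sizes are inferred from the kernel code, and CalculateAsymmetric is assumed to be one of the coordinate functions provided by nnacl_c/fp32/resize_fp32.h):

#include <stdlib.h>
#include "nnacl_c/errorcode.h"
#include "nnacl_c/fp16/resize_fp16.h"

/* Resize a 1x4x4x8 fp16 tensor to 1x8x8x8: the Prepare step fills per-row and
 * per-column index/weight tables, which the kernel then consumes.  line0/line1
 * are scratch rows of new_width * channel elements each. */
static int BilinearResizeSketch(const float16_t *in, float16_t *out) {
  int in_shape[4] = {1, 4, 4, 8};
  int out_shape[4] = {1, 8, 8, 8};
  int new_h = out_shape[1], new_w = out_shape[2], channel = in_shape[3];
  int *y_bottoms = (int *)malloc(new_h * sizeof(int));
  int *y_tops = (int *)malloc(new_h * sizeof(int));
  int *x_lefts = (int *)malloc(new_w * sizeof(int));
  int *x_rights = (int *)malloc(new_w * sizeof(int));
  float16_t *y_weights = (float16_t *)malloc(new_h * sizeof(float16_t));
  float16_t *x_weights = (float16_t *)malloc(new_w * sizeof(float16_t));
  float16_t *line0 = (float16_t *)malloc(new_w * channel * sizeof(float16_t));
  float16_t *line1 = (float16_t *)malloc(new_w * channel * sizeof(float16_t));
  /* NULL checks elided for brevity. */
  int ret = PrepareResizeBilinearFp16(in_shape, out_shape, CalculateAsymmetric, y_bottoms, y_tops, x_lefts,
                                      x_rights, y_weights, x_weights);
  if (ret == NNACL_OK) {
    ret = ResizeBilinearFp16(in, out, in_shape, out_shape, y_bottoms, y_tops, x_lefts, x_rights, y_weights,
                             x_weights, line0, line1, 0, new_h);
  }
  free(y_bottoms); free(y_tops); free(x_lefts); free(x_rights);
  free(y_weights); free(x_weights); free(line0); free(line1);
  return ret;
}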
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/scale_fp16.h" + +void Fp16ScaleInner(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t result = vfmaq_f16(offset_8, data, scale_8); + + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; + } + } + } +} + +void Fp16ScaleAxis(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t result = vfmaq_f16(offset_8, data, scale_8); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index]; + } + } +} + +void DoScaleFp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void Fp16ScaleInnerRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vmaxq_f16(tmp, zeros); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < 
inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } + } +} + +void Fp16ScaleAxisRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vmaxq_f16(tmp, zeros); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } +} + +void Fp16DoScaleRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void Fp16ScaleInnerRelu6(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vminq_f16(vmaxq_f16(tmp, zeros), bounds); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } + } +} + +void Fp16ScaleAxisRelu6(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + 
int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vminq_f16(vmaxq_f16(tmp, zeros), bounds); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } +} + +void DoScaleRelu6Fp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.h new file mode 100644 index 00000000..516bfcb9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/scale_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_SCALE_FP16_H_ +#define NNACL_FP16_SCALE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/kernel/scale.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoScaleFp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +void Fp16DoScaleRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +void DoScaleRelu6Fp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SCALE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.c new file mode 100644 index 00000000..58b4ac91 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.c @@ -0,0 +1,134 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
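The scale kernels above all share one task-partition pattern: each task takes UP_DIV(outer_size_, thread_nr_) outer rows. A single-threaded sketch with the ScaleStruct fields set exactly as DoScaleFp16 reads them (field layout per the code above; illustrative only):

#include "nnacl_c/fp16/scale_fp16.h"

/* Scale a [2, 3, 4] tensor along axis 1: out = in * scale[axis] + offset[axis].
 * outer_size_ = 2, axis_size_ = 3, inner_size_ = 4; with thread_nr_ = 1,
 * task 0 covers every outer row. */
static void ScaleSketch(const float16_t *in, float16_t *out, const float16_t *scale, const float16_t *offset) {
  ScaleStruct p;
  p.base_.thread_nr_ = 1; /* thread count field as referenced by DoScaleFp16 */
  p.outer_size_ = 2;
  p.axis_size_ = 3;
  p.inner_size_ = 4;
  DoScaleFp16(in, out, scale, offset, /*task_id=*/0, &p);
}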
+ * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/softmax_fp16.h" +#include <math.h> +#include "nnacl_c/fp16/exp_fp16.h" + +void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + int j = 0; +#ifdef ENABLE_NEON + float16x8_t max_8 = vdupq_n_f16(-FLT16_MAX); + int count = (channel / C8NUM) * C8NUM; + for (; j < count; j += C8NUM) { + float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + j); + max_8 = vmaxq_f16(max_8, input_8); + } + float16_t max = MS_MAXVQ_F16(max_8); +#else + float16_t max = -FLT_MAX; +#endif + for (; j < channel; j++) { + float16_t input = src[cur_batch_offset + j]; + if (input > max) { + max = input; + } + } + int k = 0; +#ifdef ENABLE_NEON + int count2 = (channel / C8NUM) * C8NUM; + for (; k < count2; k += C8NUM) { + float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k); + float16x8_t output_8 = vsubq_f16(input_8, vdupq_n_f16(max)); + vst1q_f16(dst + cur_batch_offset + k, output_8); + } +#endif + for (; k < channel; k++) { + int offset = cur_batch_offset + k; + dst[offset] = src[offset] - max; + } + } +} + +void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + float16_t sum = 0.0f; + int j = 0; +#ifdef ENABLE_NEON + float16x8_t sum8 = vdupq_n_f16(0); + int count = (channel / C8NUM) * C8NUM; + for (; j < count; j += C8NUM) { + sum8 = vaddq_f16(sum8, vld1q_f16(src + cur_batch_offset + j)); + } + sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7]; +#endif + for (; j < channel; j++) { + sum += src[cur_batch_offset + j]; + } + int k = 0; +#ifdef ENABLE_NEON + const float16_t div = 1.0f / sum; + for (; k < count; k += C8NUM) { + vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div)); + } +#endif + for (; k < channel; k++) { + dst[cur_batch_offset + k] = src[cur_batch_offset + k] / sum; + } + } +} + +void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel) { + SoftmaxNormFp16(src, dst, batch, channel); + ExpFp16(dst, dst, batch * channel); + SumAndDivFp16(dst, dst, batch, channel); +} + +// output = exp(input) / reduce_sum(exp(input), axis) +void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int axis, int n_dim, + const int *input_shape) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float16_t max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ?
max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); + sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.h new file mode 100644 index 00000000..cf5cf43e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/softmax_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_SOFTMAX_FP16_H_ +#define NNACL_FP16_SOFTMAX_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channel); +void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int axis, int n_dim, + const int *input_shape); +void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SOFTMAX_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.c new file mode 100644 index 00000000..3a2c6639 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.c @@ -0,0 +1,78 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
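A usage sketch for the two softmax entry points above (illustrative only). The last-axis fast path needs no scratch memory; the generic SoftmaxFp16 needs a sum buffer of outer_size * inner_size elements, which follows from the sum_data indexing in the loops above:

#include "nnacl_c/fp16/softmax_fp16.h"

/* Fast path: softmax over the last axis of a [4, 10] buffer. */
static void SoftmaxLastAxisSketch(const float16_t *src, float16_t *dst) {
  SoftmaxLastAxisFp16(src, dst, /*batch=*/4, /*channel=*/10);
}

/* Generic path: softmax over axis 1 of a [2, 5, 3] tensor.
 * sum_buf must hold outer * inner = 2 * 3 = 6 elements. */
static void SoftmaxAxisSketch(const float16_t *src, float16_t *dst, float16_t *sum_buf) {
  int shape[3] = {2, 5, 3};
  SoftmaxFp16(src, dst, sum_buf, /*axis=*/1, /*n_dim=*/3, shape);
}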
+ */ +#include "nnacl_c/fp16/sparse_to_dense_fp16.h" +#include "nnacl_c/errorcode.h" + +int SparseToDenseSetDefaultFp16(float16_t *output, float16_t default_value, SparseToDenseParameter *param, + int task_id) { + if (output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->output_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->output_num); + for (int i = begin; i < end; i++) { + output[i] = default_value; + } + return NNACL_OK; +} + +int SparseToDenseFp16(int *indices_vec, const float16_t *sparse_values, float16_t default_value, float16_t *output, + SparseToDenseParameter *param, int task_id) { + if (indices_vec == NULL || sparse_values == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->index_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->index_num); + + int stride0 = param->output_stride[0]; + int stride1 = param->output_stride[1]; + int stride2 = param->output_stride[2]; + + if (param->validate_indices_ == true) { + int index_before = -1; + for (int i = begin; i < end; i++) { + int *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + if (index <= index_before) { + return NNACL_ERR; + } + index_before = index; + } + } + + if (param->is_scalar == true) { + for (int i = begin; i < end; i++) { + int *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[0]; + } + } else { + for (int i = begin; i < end; i++) { + int *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[i]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.h new file mode 100644 index 00000000..7f225dba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/sparse_to_dense_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
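An end-to-end sketch for the scatter pair above: fill the dense output with the default value, then scatter the sparse entries. Each index entry is DIMENSION_4D ints and output_stride[] holds element strides, as the code above reads them (illustrative only):

#include <stdbool.h>
#include "nnacl_c/errorcode.h"
#include "nnacl_c/fp16/sparse_to_dense_fp16.h"

/* Scatter two values into a dense [4, 4, 1, 1] output (16 elements). */
static int SparseToDenseSketch(float16_t *output) {
  int indices[2 * 4] = {0, 1, 0, 0, 2, 3, 0, 0}; /* entries (0,1,0,0) and (2,3,0,0) */
  float16_t values[2] = {1.0f, 2.0f};
  SparseToDenseParameter param = {0};
  param.op_parameter_.thread_num_ = 1;
  param.output_num = 16;
  param.index_num = 2;
  param.output_stride[0] = 4; /* element stride of dim 0 for shape [4, 4, 1, 1] */
  param.output_stride[1] = 1;
  param.output_stride[2] = 1;
  param.validate_indices_ = false;
  param.is_scalar = false;
  int ret = SparseToDenseSetDefaultFp16(output, 0.0f, &param, /*task_id=*/0);
  if (ret != NNACL_OK) {
    return ret;
  }
  return SparseToDenseFp16(indices, values, 0.0f, output, &param, /*task_id=*/0);
}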
+ */ +#ifndef NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ +#define NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ + +#include "nnacl_c/sparse_to_dense_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SparseToDenseSetDefaultFp16(float16_t *output, float16_t default_value, SparseToDenseParameter *param, int task_id); +int SparseToDenseFp16(int *indices_vec, const float16_t *sparse_values, float16_t default_value, float16_t *output, + SparseToDenseParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.c new file mode 100644 index 00000000..d4718e37 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.c @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/splice_fp16.h" +void SpliceFp16(const float16_t *src_data, int src_row, int src_col, const SpliceParameter *param, float16_t *dst_data, + int dst_row, int dst_col) { + int forward_index = 0; + for (int r = 0; r < dst_row; ++r) { + float16_t *dst_row_data = dst_data + r * dst_col; + for (int off = 0; off < param->context_dim_; ++off) { + int r_off = param->forward_indexes_[forward_index]; + forward_index++; + const float16_t *tmp_src_data = src_data + r_off * src_col; + float16_t *tmp_dst_data = dst_row_data + off * src_col; + memcpy(tmp_dst_data, tmp_src_data, (size_t)(src_col) * sizeof(float16_t)); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.h new file mode 100644 index 00000000..2c2272fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/splice_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
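SpliceFp16 above concatenates context_dim_ input rows per output row, selected through forward_indexes_, so dst_col must equal context_dim_ * src_col. A small sketch (illustrative only):

#include "nnacl_c/fp16/splice_fp16.h"

/* Splice a [4, 2] input into a [2, 6] output with a context of 3 rows:
 * output row 0 = input rows {0, 1, 2}, output row 1 = input rows {1, 2, 3}. */
static void SpliceSketch(const float16_t *src, float16_t *dst) {
  int forward_indexes[6] = {0, 1, 2, 1, 2, 3};
  SpliceParameter param = {0};
  param.context_dim_ = 3;
  param.forward_indexes_ = forward_indexes;
  SpliceFp16(src, /*src_row=*/4, /*src_col=*/2, &param, dst, /*dst_row=*/2, /*dst_col=*/6);
}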
+ */ + +#ifndef NNACL_FP16_SPLICE_FP16_H_ +#define NNACL_FP16_SPLICE_FP16_H_ +#include <string.h> +#include "nnacl_c/splice_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +void SpliceFp16(const float16_t *src_data, int src_row, int src_col, const SpliceParameter *param, float16_t *dst_data, + int dst_row, int dst_col); + +#ifdef __cplusplus +} +#endif +#endif  // NNACL_FP16_SPLICE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.c new file mode 100644 index 00000000..feeeddc9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.c @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/topk_fp16.h" + +int TopkFp16DescendCmp(const void *a, const void *b) { + float16_t sub = ((const TopkFp16Node *)b)->element - ((const TopkFp16Node *)a)->element; + if (sub > 0) { + return 1; + } else if (sub < 0) { + return -1; + } + if (((const TopkFp16Node *)a)->index > ((const TopkFp16Node *)b)->index) { + return 1; + } else { + return -1; + } +} + +int TopkFp16IndexSortCmp(const void *a, const void *b) { + if (((const TopkFp16Node *)a)->index > ((const TopkFp16Node *)b)->index) { + return 1; + } else { + return -1; + } +} + +void TopkFp16(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkFp16Node *top_map = (TopkFp16Node *)parameter->topk_node_list_; + + float16_t *cur_input_data = (float16_t *)input_data; + float16_t *cur_output_data = (float16_t *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = *(cur_input_data + offset); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), TopkFp16DescendCmp); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), TopkFp16IndexSortCmp); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = top_map[m].element; + cur_output_index[offset] = top_map[m].index; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.h new file mode 100644 index 00000000..4054851d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/topk_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
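A usage sketch for TopkFp16 above: topk_node_list_ is caller-provided scratch holding dim_size_ TopkFp16Node entries, and sorted_ == false re-orders the top-k hits back into original-index order (illustrative only):

#include <stdbool.h>
#include <stdlib.h>
#include "nnacl_c/fp16/topk_fp16.h"

/* Top-2 along the last axis of a [3, 5] tensor. */
static void TopkSketch(float16_t *in, float16_t *out, int32_t *out_idx) {
  TopkParameter param = {0};
  param.dim_size_ = 5;       /* length of the sorted axis */
  param.outer_loop_num_ = 3; /* product of dims before the axis */
  param.inner_loop_num_ = 1; /* product of dims after the axis */
  param.k_ = 2;
  param.sorted_ = true; /* keep descending value order */
  param.topk_node_list_ = malloc(param.dim_size_ * sizeof(TopkFp16Node));
  if (param.topk_node_list_ == NULL) {
    return;
  }
  TopkFp16(in, out, out_idx, &param);
  free(param.topk_node_list_);
}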
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_TOPK_FP16_H_ +#define NNACL_FP16_TOPK_FP16_H_ + +#include "nnacl_c/fp32/topk_fp32.h" +#include "nnacl_c/op_base.h" + +typedef struct TopkFp16Node { + float16_t element; + int32_t index; +} TopkFp16Node; + +#ifdef __cplusplus +extern "C" { +#endif +void TopkFp16(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_TOPK_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.c new file mode 100644 index 00000000..d211eda1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.c @@ -0,0 +1,257 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16/transpose_fp16.h" +#include <string.h> +#include "nnacl_c/errorcode.h" + +void Fp16TransposeDim2(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } +} + +void Fp16TransposeDim3(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void Fp16TransposeDim4(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void Fp16TransposeDim5(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k =
0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void Fp16TransposeDim6(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_n = n * out_stride4; + int stride4_n = n * stride4; + for (int g = 0; g < output5; ++g) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5]; + } + } + } + } + } + } +} + +void TransposeDimsFp16(const void *in, void *out, const int *output_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread_num) { + const float16_t *in_data = (const float16_t *)in; + float16_t *out_data = (float16_t *)out; + + NNACL_CHECK_NULL_RETURN_VOID(in_data); + NNACL_CHECK_NULL_RETURN_VOID(out_data); + NNACL_CHECK_NULL_RETURN_VOID(output_shape); + NNACL_CHECK_NULL_RETURN_VOID(perm); + NNACL_CHECK_NULL_RETURN_VOID(strides); + NNACL_CHECK_NULL_RETURN_VOID(out_strides); + NNACL_CHECK_ZERO_RETURN(thread_num); + + size_t data_size = (*out_strides) * output_shape[0]; + size_t offset_size = UP_DIV(data_size, thread_num); + size_t task_offset = offset_size * task_id; + int count = data_size - task_offset; + if (count <= 0) { + return; + } + count = MSMIN(offset_size, count); + for (size_t idx = task_offset; idx < task_offset + count; ++idx) { + int pos = idx; + int output_idx = 0; + int input_idx = 0; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); + int position = pos / *(out_strides + i); + int out_stride = i < num_axes - 1 ? 
out_strides[i] : 1; + output_idx += (position * out_stride); + input_idx += (position * strides[perm[i]]); + pos -= position * (*(out_strides + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} + +int DoTransposeFp16(const void *in, void *out, const int *output_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes) { + const float16_t *in_data = (const float16_t *)in; + float16_t *out_data = (float16_t *)out; + + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + NNACL_CHECK_NULL_RETURN_ERR(perm); + NNACL_CHECK_NULL_RETURN_ERR(strides); + NNACL_CHECK_NULL_RETURN_ERR(out_strides); + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; ++i) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; + } + for (int i = 0; i < num_axes; ++i) { + if (perm[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (num_axes == 2) { + Fp16TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + Fp16TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + Fp16TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 5) { + Fp16TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 6) { + Fp16TransposeDim6(in_data, out_data, strides, out_strides, perm, output_shape); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.h new file mode 100644 index 00000000..36ba018b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/transpose_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
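A minimal call sketch for DoTransposeFp16 above: strides are element strides of the input, out_strides of the output, and data_size is a byte count consumed only by the identity-permutation memcpy fast path (illustrative only):

#include "nnacl_c/errorcode.h"
#include "nnacl_c/fp16/transpose_fp16.h"

/* Transpose a [2, 3] matrix into [3, 2] with perm = {1, 0}. */
static int TransposeSketch(const float16_t *in, float16_t *out) {
  int perm[2] = {1, 0};
  int in_strides[2] = {3, 1}; /* input shape [2, 3] */
  int out_shape[2] = {3, 2};
  int out_strides[2] = {2, 1};
  return DoTransposeFp16(in, out, out_shape, perm, in_strides, out_strides,
                         (int)(6 * sizeof(float16_t)), /*num_axes=*/2);
}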
+ */ + +#ifndef NNACL_FP16_TRANSPOSE_FP16_H_ +#define NNACL_FP16_TRANSPOSE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TransposeDimsFp16(const void *src, void *dst, const int *output_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread_num); +int DoTransposeFp16(const void *src, void *dst, const int *output_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_TRANSPOSE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.c new file mode 100644 index 00000000..cb876aa9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/unique_fp16.h" + +int FindFp16(const float16_t *array, int len, float16_t target) { + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void UniqueFp16(const float16_t *input, int input_len, float16_t *output0, int *output0_len, int *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = FindFp16(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.h new file mode 100644 index 00000000..c5d7defa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/unique_fp16.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
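A usage sketch for UniqueFp16 above; output0 must be sized for the worst case of input_len unique values (illustrative only):

#include "nnacl_c/fp16/unique_fp16.h"

/* Deduplicate a length-6 vector. */
static void UniqueSketch(void) {
  float16_t input[6] = {1, 3, 1, 2, 3, 3};
  float16_t uniq[6];
  int uniq_len = 0;
  int idx[6];
  UniqueFp16(input, 6, uniq, &uniq_len, idx);
  /* Result: uniq = {1, 3, 2}, uniq_len = 3, idx = {0, 1, 0, 2, 1, 1}. */
}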
+ */ +#ifndef NNACL_FP16_UNIQUE_FP16_H +#define NNACL_FP16_UNIQUE_FP16_H + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void UniqueFp16(const float16_t *input, int input_len, float16_t *output0, int *output0_len, int *output1); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_UNIQUE_FP16_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.c new file mode 100644 index 00000000..0ff524c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.c @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/utils_fp16.h" +#include "nnacl_c/fp16/common_func_fp16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +void *GetOrAllocFp16Data(TensorC *t, ExecEnv *env, bool cast) { + NNACL_CHECK_NULL_RETURN_NULL(t); + if (t->data_type_ == kNumberTypeFloat16) { + return t->data_; + } + if (t->data_type_ == kNumberTypeFloat32) { + int ele_num = NNACLGetElementNum(t); + void *fp16_data = env->Alloc(env->allocator_, ele_num * sizeof(float16_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fp16_data); + if (cast) { + Float32ToFloat16((float *)t->data_, (float16_t *)fp16_data, ele_num); + } + return fp16_data; + } + return NULL; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.h new file mode 100644 index 00000000..223f9232 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/utils_fp16.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
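GetOrAllocFp16Data above hands back the tensor's own buffer when it is already fp16 and otherwise allocates an fp16 copy through the ExecEnv allocation hook. A sketch with a malloc-backed hook; the hook's exact signature lives in nnacl_c/kernel.h, so the (allocator, size) shape used here is an assumption (illustrative only):

#include <stdbool.h>
#include <stdlib.h>
#include "nnacl_c/fp16/utils_fp16.h"

/* Assumed allocation-hook signature: (opaque allocator, byte count). */
static void *MallocHook(void *allocator, size_t size) {
  (void)allocator;
  return malloc(size);
}

static void Fp16DataSketch(TensorC *fp32_tensor, ExecEnv *env) {
  env->Alloc = MallocHook; /* hook field as referenced by GetOrAllocFp16Data */
  void *fp16 = GetOrAllocFp16Data(fp32_tensor, env, /*cast=*/true);
  (void)fp16; /* in real code the buffer is released through the same env */
}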
+ */ +#ifndef NNACL_FP16_UTILS_FP16_H_ +#define NNACL_FP16_UTILS_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +void *GetOrAllocFp16Data(TensorC *t, ExecEnv *env, bool cast); + +#endif // NNACL_FP16_UTILS_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.c new file mode 100644 index 00000000..eb458e26 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16/where_fp16.h" +#include "nnacl_c/common_func.h" + +void WhereWithTripleInputsFp16(const float16_t *x, const float16_t *y, float16_t *output, const WhereArgs *param, + int task_id, int thread_num) { + const bool *condition = param->condition_; + int stride = UP_DIV(param->max_num_, thread_num); + int begin = task_id * stride; + int end = MSMIN(begin + stride, param->max_num_); + + for (int i = begin; i < end; ++i) { + bool cond = condition[param->condition_num_ > 1 ? i : 0]; + if (cond) { + output[i] = x[param->x_num_ > 1 ? i : 0]; + } else { + output[i] = y[param->y_num_ > 1 ? i : 0]; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.h new file mode 100644 index 00000000..6d927e0e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/where_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
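A broadcast-select sketch for WhereWithTripleInputsFp16 above: any of condition/x/y whose count is 1 is broadcast against max_num_ (illustrative only):

#include <stdbool.h>
#include "nnacl_c/fp16/where_fp16.h"

/* Pick between a broadcast scalar x and a full-size y, element by element. */
static void WhereSketch(float16_t *out) {
  bool cond[4] = {true, false, true, false};
  float16_t x[1] = {9.0f};
  float16_t y[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  WhereArgs args = {0};
  args.condition_ = cond;
  args.condition_num_ = 4;
  args.x_num_ = 1;
  args.y_num_ = 4;
  args.max_num_ = 4;
  WhereWithTripleInputsFp16(x, y, out, &args, /*task_id=*/0, /*thread_num=*/1);
  /* Result: out = {9, 1, 9, 3}. */
}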
+ */ +#ifndef NNACL_FP16_WHERE_FP16_H_ +#define NNACL_FP16_WHERE_FP16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/where_parameter.h" +#include "nnacl_c/kernel/where.h" + +#ifdef __cplusplus +extern "C" { +#endif +void WhereWithTripleInputsFp16(const float16_t *x, const float16_t *y, float16_t *output, const WhereArgs *param, + int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WHERE_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c new file mode 100644 index 00000000..452d8ab4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.c @@ -0,0 +1,360 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16/winograd_transform_fp16.h" + +void PrepareTransInputFp16(const float16_t *src_data, float16_t *dst_data, int interval_x_s, int interval_x_e, + int interval_y_s, int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; + + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); + } + + // get real input block with padding + if (real_c == C8NUM) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f16(dst_addr, vld1q_f16(src_addr)); +#else + for (int k = 0; k < C8NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } + } + } else if (real_c < 8 && real_c >= 4) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + int rc = real_c - 4; +#ifdef ENABLE_NEON + vst1_f16(dst_addr, vld1_f16(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + src_addr += 4; + dst_addr += 4; + for (int i = 0; i < rc; ++i) { + dst_addr[i] = src_addr[i]; + } + } + } + } else { + for (int interval = interval_y_s; interval < interval_y_e; 
interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } + } + } +} + +// fp16 common winograd +void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFp16Func func) { +#ifdef ENABLE_ARM64 + const int tile_num = 16; +#else + const int tile_num = 12; +#endif + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic8 = UP_DIV(in_channel, C8NUM); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + if (out_w_block_num == 0) { + return; + } + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * in_channel; + for (int ic = 0; ic < ic8; ic++) { + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); + + // input transform + int dst_ic8_offset = dst_plane_offset + ic * C8NUM; + size_t dst_step = in_channel * tile_num; + float16_t *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c); + } + out_tile_index++; + } // cal_tile_num loop +} + +// Only support arm64 +void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int cal_num, int out_tile_index, int out_w_block_num, + const ConvParameter *conv_param, InputTransStepFp16Func func) { + const int tile_num = 16; + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic8 = UP_DIV(in_channel, C8NUM); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + if (out_w_block_num == 0) { + return; + } + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? 
input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); + + // input transform + int dst_ic8_offset = dst_plane_offset + ic * tile_num * input_unit * input_unit * C8NUM; + size_t dst_step = input_unit * tile_num * C8NUM; + float16_t *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, C8NUM, dst_step, tile_num * C8NUM); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_channel = conv_param->output_channel_; + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w); + + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } + out_tile_index++; + } +} + +void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int plane = output_w * output_h; + int output_channel = conv_param->output_channel_; + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? 
output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = dst_x_s + dst_y_s * output_w; + + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = (dst_tile_offset + plane * j) * C8NUM; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } + out_tile_index++; + } +} + +int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g, + const float *matrix_gt, int oc_block, int input_unit, int kernel_unit, + int filter_channel, int filter_batch, bool pack) { + // original weight format : ohwi + int oc_block_num = UP_DIV(filter_batch, oc_block); + int block_stride = filter_channel * oc_block; + int block_num_stride = block_stride * oc_block_num; + + float16_t *matrix_gt_data_fp16 = (float16_t *)(malloc(input_unit * kernel_unit * sizeof(float16_t))); + if (matrix_gt_data_fp16 == NULL) { + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit * kernel_unit); + + // trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T + // separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T + float16_t *tmp_data = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t))); + if (tmp_data == NULL) { + free(matrix_gt_data_fp16); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + float16_t *trans_out_data = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t))); + if (trans_out_data == NULL) { + free(tmp_data); + free(matrix_gt_data_fp16); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + +#ifndef ENABLE_ARM64 + float16_t *tmp_data1 = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t))); + if (tmp_data1 == NULL) { + free(tmp_data); + free(matrix_gt_data_fp16); + free(trans_out_data); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + float16_t *trans_out_data1 = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t))); + if (trans_out_data1 == NULL) { + free(tmp_data); + free(tmp_data1); + free(matrix_gt_data_fp16); + free(trans_out_data); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } +#endif + + int input_oz_offset = kernel_unit * kernel_unit * filter_channel; + for (int i = 0; i < filter_batch; i++) { + int out_c_block = i / oc_block; + int out_c_res = i % oc_block; + int output_oz_offset = out_c_block * block_stride + out_c_res; + +#ifndef ENABLE_ARM64 + // tmp_data = g * GT + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit, + kernel_unit, input_unit, filter_channel); + // tmp_data1 = (tmp_data)T + PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit, input_unit, filter_channel); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit, kernel_unit, input_unit, + filter_channel); + // trans_out_data = (trans_out_data1)T + 
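    // i.e. the pack below yields trans_out_data = [(g * (G)T)T * (G)T]T = G * g * (G)T
+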
PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit, input_unit, filter_channel); +#else + // tmp = (g * GT)T + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit, + kernel_unit, input_unit, filter_channel); + // trans = (tmp * GT)T + MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit, + filter_channel); +#endif + + if (pack) { + int in_offset = 0; + for (int j = 0; j < input_unit; ++j) { + for (int k = 0; k < input_unit; ++k) { + for (int c = 0; c < filter_channel; ++c) { + *(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; + } + in_offset += filter_channel; + output_oz_offset += block_num_stride; + } + } + } else { + memcpy(winograd_data + i * filter_channel * input_unit * input_unit, trans_out_data, + filter_channel * input_unit * input_unit * sizeof(float16_t)); + } + } + +#ifndef ENABLE_ARM64 + free(tmp_data1); + free(trans_out_data1); +#endif + free(tmp_data); + free(trans_out_data); + free(matrix_gt_data_fp16); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.h new file mode 100644 index 00000000..3ace951b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_transform_fp16.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_
+#define NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_
+
+#include <arm_neon.h>
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp16/cast_fp16.h"
+#include "nnacl_c/fp16/conv_fp16.h"
+#include "nnacl_c/fp16/matrix_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+// fp16 common winograd
+void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
+                                int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
+                                InputTransFp16Func func);
+
+void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
+                                       int cal_num, int out_tile_index, int out_w_block_num,
+                                       const ConvParameter *conv_param, InputTransStepFp16Func func);
+
+void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
+                                     int cal_num, int out_tile_index, int output_unit_num,
+                                     const ConvParameter *conv_param, OutputTransFp16Func func);
+
+void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
+                                       int cal_num, int out_tile_index, int output_unit_num,
+                                       const ConvParameter *conv_param, OutputTransFp16Func func);
+
+// fp16 winograd weight trans
+int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g,
+                                const float *matrix_gt, int oc_block, int input_unit, int kernel_unit,
+                                int filter_channel, int filter_batch, bool pack);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c
new file mode 100644
index 00000000..0f58223c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.c
@@ -0,0 +1,3278 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp16/winograd_utils_fp16.h" +#include "nnacl_c/fp16/matrix_fp16.h" + +#define MIN_UNIT_FP16 2 +#define MAX_UNIT_FP16 4 + +#ifdef ENABLE_ARM64 +void transpose8(float16x8_t *s0, float16x8_t *s1, float16x8_t *s2, float16x8_t *s3, float16x8_t *s4, float16x8_t *s5, + float16x8_t *s6, float16x8_t *s7) { + float32x4_t m0 = (float32x4_t)(vtrn1q_f16(*s0, *s1)); + float32x4_t m1 = (float32x4_t)(vtrn2q_f16(*s0, *s1)); + float32x4_t m2 = (float32x4_t)(vtrn1q_f16(*s2, *s3)); + float32x4_t m3 = (float32x4_t)(vtrn2q_f16(*s2, *s3)); + float32x4_t m4 = (float32x4_t)(vtrn1q_f16(*s4, *s5)); + float32x4_t m5 = (float32x4_t)(vtrn2q_f16(*s4, *s5)); + float32x4_t m6 = (float32x4_t)(vtrn1q_f16(*s6, *s7)); + float32x4_t m7 = (float32x4_t)(vtrn2q_f16(*s6, *s7)); + + float64x2_t t0 = (float64x2_t)(vtrn1q_f32(m0, m2)); + float64x2_t t2 = (float64x2_t)(vtrn2q_f32(m0, m2)); + float64x2_t t1 = (float64x2_t)(vtrn1q_f32(m1, m3)); + float64x2_t t3 = (float64x2_t)(vtrn2q_f32(m1, m3)); + float64x2_t t4 = (float64x2_t)(vtrn1q_f32(m4, m6)); + float64x2_t t6 = (float64x2_t)(vtrn2q_f32(m4, m6)); + float64x2_t t5 = (float64x2_t)(vtrn1q_f32(m5, m7)); + float64x2_t t7 = (float64x2_t)(vtrn2q_f32(m5, m7)); + + *s0 = (float16x8_t)(vtrn1q_f64(t0, t4)); + *s4 = (float16x8_t)(vtrn2q_f64(t0, t4)); + *s1 = (float16x8_t)(vtrn1q_f64(t1, t5)); + *s5 = (float16x8_t)(vtrn2q_f64(t1, t5)); + *s2 = (float16x8_t)(vtrn1q_f64(t2, t6)); + *s6 = (float16x8_t)(vtrn2q_f64(t2, t6)); + *s3 = (float16x8_t)(vtrn1q_f64(t3, t7)); + *s7 = (float16x8_t)(vtrn2q_f64(t3, t7)); +} +#endif + +static InputTransFp16Func InputTransFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4UnitFp16, NULL, InputTransform6x6UnitFp16, NULL, InputTransform8x8UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList4[] = {NULL, NULL, OutputTransform4x2UnitFp16, + OutputTransform4x3UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList4[] = {NULL, NULL, OutputTransform4x2ReluUnitFp16, + OutputTransform4x3ReluUnitFp16}; +static OutputTransFp16Func OutputTransFp16FuncRelu6List4[] = {NULL, NULL, OutputTransform4x2Relu6UnitFp16, + OutputTransform4x3Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList6[] = {NULL, + NULL, + OutputTransform6x2UnitFp16, + OutputTransform6x3UnitFp16, + OutputTransform6x4UnitFp16, + OutputTransform6x5UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList6[] = {NULL, + NULL, + OutputTransform6x2ReluUnitFp16, + OutputTransform6x3ReluUnitFp16, + OutputTransform6x4ReluUnitFp16, + OutputTransform6x5ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List6[] = {NULL, + NULL, + OutputTransform6x2Relu6UnitFp16, + OutputTransform6x3Relu6UnitFp16, + OutputTransform6x4Relu6UnitFp16, + OutputTransform6x5Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList8[] = {NULL, + NULL, + OutputTransform8x2UnitFp16, + OutputTransform8x3UnitFp16, + OutputTransform8x4UnitFp16, + OutputTransform8x5UnitFp16, + OutputTransform8x6UnitFp16, + OutputTransform8x7UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList8[] = {NULL, + NULL, + OutputTransform8x2ReluUnitFp16, + OutputTransform8x3ReluUnitFp16, + OutputTransform8x4ReluUnitFp16, + OutputTransform8x5ReluUnitFp16, + OutputTransform8x6ReluUnitFp16, + OutputTransform8x7ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL, + NULL, + OutputTransform8x2Relu6UnitFp16, + OutputTransform8x3Relu6UnitFp16, + OutputTransform8x4Relu6UnitFp16, + OutputTransform8x5Relu6UnitFp16, + 
OutputTransform8x6Relu6UnitFp16, + OutputTransform8x7Relu6UnitFp16}; + +InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; } + +#ifdef ENABLE_ARM64 +static InputTransStepFp16Func InputTransStepFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4StepFp16, NULL, InputTransform6x6StepFp16, NULL, InputTransform8x8StepFp16}; + +static InputTransPackFp16Func InputTransPackFp16FuncList[] = {NULL, + NULL, + NULL, + NULL, + InputTransform4x4Pack16Fp16, + NULL, + InputTransform6x6Pack16Fp16, + NULL, + InputTransform8x8Pack16Fp16}; + +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit) { return InputTransStepFp16FuncList[input_unit]; } + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit) { return InputTransPackFp16FuncList[input_unit]; } +#endif + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[16]; + float16x8_t t[16]; + float16x8_t m[16]; + Load16DataFp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsubq_f16(src[offset], src[2 + offset]); + t[4 + l] = vaddq_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsubq_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsubq_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsubq_f16(t[offset], t[2 + offset]); + m[4 + l] = vaddq_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsubq_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsubq_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[16]; + float16x4_t t[16]; + float16x4_t m[16]; + Load16DataC4Fp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsub_f16(src[offset], src[2 + offset]); + t[4 + l] = vadd_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsub_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsub_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsub_f16(t[offset], t[2 + offset]); + m[4 + l] = vadd_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsub_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsub_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[16]; + float16_t t[16]; + float16_t m[16]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 4; ++l) { + const float16_t *src_ptr = src_data + l * 4 * src_step; + float16_t *dst_ptr = dst_data + l * 
dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t m0 = vsubq_f16(s0, s2); + float16x8_t m1 = vaddq_f16(s1, s2); + float16x8_t m2 = vsubq_f16(s2, s1); + float16x8_t m3 = vsubq_f16(s3, s1); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + + float16x8_t m0 = vsubq_f16(s00, s20); + float16x8_t m1 = vsubq_f16(s01, s21); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(s10, s20); + m1 = vaddq_f16(s11, s21); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s20, s10); + m1 = vsubq_f16(s21, s11); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s30, s10); + m1 = vsubq_f16(s31, s11); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 4; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[36]; + float16x8_t t[36]; + float16x8_t m[36]; + Load36DataFp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(src[3 + offset], src[1 + offset]); + float16x8_t tmp2 = vsubq_f16(src[4 + offset], src[2 + offset]); + t[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 4), vmulq_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vaddq_f16(vmulq_n_f16(vaddq_f16(src[1 + offset], src[2 + offset]), -4), + vaddq_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = vaddq_f16(vmulq_n_f16(vsubq_f16(src[1 + offset], src[2 + offset]), 4), + vsubq_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + t[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + t[30 + l] = + vaddq_f16(vsubq_f16(vmulq_n_f16(src[1 + offset], 4), vmulq_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(t[3 + offset], t[1 + offset]); + float16x8_t tmp2 = vsubq_f16(t[4 + offset], t[2 + offset]); + m[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 4), vmulq_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vaddq_f16(vmulq_n_f16(vaddq_f16(t[1 + offset], t[2 + offset]), -4), vaddq_f16(t[3 + offset], t[4 + 
offset])); + m[12 + l] = + vaddq_f16(vmulq_n_f16(vsubq_f16(t[1 + offset], t[2 + offset]), 4), vsubq_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + m[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + m[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[1 + offset], 4), vmulq_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[36]; + float16x4_t t[36]; + float16x4_t m[36]; + Load36DataC4Fp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(src[3 + offset], src[1 + offset]); + float16x4_t tmp2 = vsub_f16(src[4 + offset], src[2 + offset]); + t[l] = vadd_f16(vsub_f16(vmul_n_f16(src[offset], 4), vmul_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vadd_f16(vmul_n_f16(vadd_f16(src[1 + offset], src[2 + offset]), -4), + vadd_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(src[1 + offset], src[2 + offset]), 4), vsub_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + t[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + t[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(src[1 + offset], 4), vmul_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(t[3 + offset], t[1 + offset]); + float16x4_t tmp2 = vsub_f16(t[4 + offset], t[2 + offset]); + m[l] = vadd_f16(vsub_f16(vmul_n_f16(t[offset], 4), vmul_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vadd_f16(vmul_n_f16(vadd_f16(t[1 + offset], t[2 + offset]), -4), vadd_f16(t[3 + offset], t[4 + offset])); + m[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(t[1 + offset], t[2 + offset]), 4), vsub_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + m[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + m[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(t[1 + offset], 4), vmul_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[36]; + float16_t t[36]; + float16_t m[36]; + for (int k = 0; k < 36; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = src[3 + offset] - src[1 + offset]; + float16_t tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = src[offset] * 4 - src[2 + offset] * 5 + src[4 + offset]; + t[6 + l] = (src[1 + offset] + src[2 + offset]) * -4 + (src[3 + offset] + src[4 + offset]); + t[12 + l] = (src[1 + offset] - src[2 + offset]) * 4 + (src[4 + offset] - src[3 + offset]); + t[18 + l] = tmp1 * 2 + tmp2; + t[24 + l] = tmp1 * -2 + tmp2; + t[30 + l] = src[1 + offset] * 4 - src[3 + offset] * 5 + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = t[3 + offset] - t[1 + offset]; + float16_t tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = t[offset] * 4 - t[2 + offset] * 5 + t[4 + offset]; + m[6 + l] = (t[1 + offset] + t[2 + offset]) * -4 + (t[3 + offset] + t[4 + offset]); + m[12 + l] = (t[1 + offset] - t[2 + offset]) * 4 + (t[4 + offset] - t[3 + offset]); + m[18 + l] = tmp1 * 2 + tmp2; + m[24 + l] = tmp1 * -2 + tmp2; + m[30 + l] = t[1 + offset] * 4 - t[3 + offset] * 5 + t[5 + offset]; + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + 
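      // scatter the 36 transformed values of channel j, one value per dst_step
+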
dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 6; ++l) { + const float16_t *src_ptr = src_data + l * 6 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step); + float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step); + + float16x8_t tmp1 = vsubq_f16(s3, s1); + float16x8_t tmp2 = vsubq_f16(s4, s2); + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 4), vmulq_n_f16(s2, 5)), s4); + float16x8_t m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s1, s2), -4), vaddq_f16(s3, s4)); + float16x8_t m2 = vaddq_f16(vmulq_n_f16(vsubq_f16(s1, s2), 4), vsubq_f16(s4, s3)); + float16x8_t m3 = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + float16x8_t m4 = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + float16x8_t m5 = vaddq_f16(vsubq_f16(vmulq_n_f16(s1, 4), vmulq_n_f16(s3, 5)), s5); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + vst1q_f16(dst_ptr + 4 * dst_step, m4); + vst1q_f16(dst_ptr + 5 * dst_step, m5); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 4), vmulq_n_f16(s20, 5)), s40); + float16x8_t m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 4), vmulq_n_f16(s21, 5)), s41); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vaddq_f16(s10, s20), -4), vaddq_f16(s30, s40)); + m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s11, s21), -4), vaddq_f16(s31, s41)); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s10, s20), 4), vsubq_f16(s40, s30)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s11, s21), 4), vsubq_f16(s41, s31)); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), 2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), 2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), -2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), -2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s10, 4), vmulq_n_f16(s30, 5)), s50); + m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s11, 4), vmulq_n_f16(s31, 5)), s51); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = 
src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 6; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[64]; + float16x8_t t[64]; + float16x8_t m[64]; + Load64DataFp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 0.5625), vmulq_n_f16(src[2 + offset], 3.0625)), + vmulq_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 1.125), vmulq_n_f16(src[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 2.25), vmulq_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.5625), vmulq_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.375), vmulq_n_f16(src[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.25), vmulq_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(src[1 + offset], -0.5625), vmulq_n_f16(src[3 + offset], 3.0625)), + vmulq_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 0.5625), vmulq_n_f16(t[2 + offset], 3.0625)), + vmulq_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 1.125), vmulq_n_f16(t[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 2.25), vmulq_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.5625), vmulq_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.375), vmulq_n_f16(t[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.25), vmulq_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), 
vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(t[1 + offset], -0.5625), vmulq_n_f16(t[3 + offset], 3.0625)), + vmulq_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[64]; + float16x4_t t[64]; + float16x4_t m[64]; + Load64DataC4Fp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(src[offset], 0.5625), vmul_n_f16(src[2 + offset], 3.0625)), + vmul_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 1.125), vmul_n_f16(src[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 2.25), vmul_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.5625), vmul_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.375), vmul_n_f16(src[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.25), vmul_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(src[1 + offset], -0.5625), vmul_n_f16(src[3 + offset], 3.0625)), + vmul_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(t[offset], 0.5625), vmul_n_f16(t[2 + offset], 3.0625)), + vmul_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 1.125), vmul_n_f16(t[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 2.25), vmul_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.5625), vmul_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.375), vmul_n_f16(t[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.25), vmul_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = 
vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(t[1 + offset], -0.5625), vmul_n_f16(t[3 + offset], 3.0625)), + vmul_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[64]; + float16_t t[64]; + float16_t m[64]; + for (int k = 0; k < 64; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] * 0.5625f - src[2 + offset] * 3.0625f + src[4 + offset] * 3.5f - src[6 + offset]; + float16_t tmp1 = src[1 + offset] * 1.125f + src[5 + offset] * 0.5f; + float16_t tmp2 = src[2 + offset] * 2.25f - src[4 + offset] * 3.25f; + t[8 + l] = tmp1 + tmp2 - src[3 + offset] * 1.625f + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + src[3 + offset] * 1.625f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.5625f + src[5 + offset]; + tmp2 = src[2 + offset] * 0.5625f - src[4 + offset] * 2.5f; + t[24 + l] = tmp1 + tmp2 - src[3 + offset] * 2.5f + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + src[3 + offset] * 2.5f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.375f + src[5 + offset] * 1.5f; + tmp2 = src[2 + offset] * 0.25f - src[4 + offset] * 1.25f; + t[40 + l] = tmp1 + tmp2 - src[3 + offset] * 1.875f + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + src[3 + offset] * 1.875f + src[6 + offset]; + t[56 + l] = src[1 + offset] * -0.5625 + src[3 + offset] * 3.0625f - src[5 + offset] * 3.5f + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = t[offset] * 0.5625f - t[2 + offset] * 3.0625f + t[4 + offset] * 3.5f - t[6 + offset]; + float16_t tmp1 = t[1 + offset] * 1.125f + t[5 + offset] * 0.5f; + float16_t tmp2 = t[2 + offset] * 2.25f - t[4 + offset] * 3.25f; + m[8 + l] = tmp1 + tmp2 - t[3 + offset] * 1.625f + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + t[3 + offset] * 1.625f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.5625f + t[5 + offset]; + tmp2 = t[2 + offset] * 0.5625f - t[4 + offset] * 2.5f; + m[24 + l] = tmp1 + tmp2 - t[3 + offset] * 2.5f + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + t[3 + offset] * 2.5f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.375f + t[5 + offset] * 1.5f; + tmp2 = t[2 + offset] * 0.25f - t[4 + offset] * 1.25f; + m[40 + l] = tmp1 + tmp2 - t[3 + offset] * 1.875f + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + t[3 + offset] * 1.875f + t[6 + offset]; + m[56 + l] = t[1 + offset] * -0.5625 + t[3 + offset] * 3.0625f - t[5 + offset] * 3.5f + t[7 + offset]; + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform8x8StepFp16_uint(float16x8_t *s, float16x8_t *m) { + m[0] = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s[0], 0.5625), vmulq_n_f16(s[2], 3.0625)), vmulq_n_f16(s[4], 3.5)), s[6]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(s[1], 1.125), vmulq_n_f16(s[5], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(s[2], 2.25), vmulq_n_f16(s[4], 3.25)); + m[1] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 1.625)), s[6]); + m[2] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 1.625)), s[6]); + tmp1 = vaddq_f16(vmulq_n_f16(s[1], 0.5625), s[5]); + tmp2 = vsubq_f16(vmulq_n_f16(s[2], 0.5625), vmulq_n_f16(s[4], 2.5)); + m[3] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 2.5)), s[6]); + m[4] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 2.5)), s[6]); + tmp1 = vaddq_f16(vmulq_n_f16(s[1], 0.375), vmulq_n_f16(s[5], 1.5)); + tmp2 = 
vsubq_f16(vmulq_n_f16(s[2], 0.25), vmulq_n_f16(s[4], 1.25)); + m[5] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 1.875)), s[6]); + m[6] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 1.875)), s[6]); + m[7] = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s[1], -0.5625), vmulq_n_f16(s[3], 3.0625)), vmulq_n_f16(s[5], 3.5)), + s[7]); +} + +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 8; ++l) { + const float16_t *src_ptr = src_data + l * 8 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s[8]; + float16x8_t m[8]; + + s[0] = vld1q_f16(src_ptr + 0 * src_step); + s[1] = vld1q_f16(src_ptr + 1 * src_step); + s[2] = vld1q_f16(src_ptr + 2 * src_step); + s[3] = vld1q_f16(src_ptr + 3 * src_step); + s[4] = vld1q_f16(src_ptr + 4 * src_step); + s[5] = vld1q_f16(src_ptr + 5 * src_step); + s[6] = vld1q_f16(src_ptr + 6 * src_step); + s[7] = vld1q_f16(src_ptr + 7 * src_step); + + InputTransform8x8StepFp16_uint(s, m); + + vst1q_f16(dst_ptr + 0 * dst_step, m[0]); + vst1q_f16(dst_ptr + 1 * dst_step, m[1]); + vst1q_f16(dst_ptr + 2 * dst_step, m[2]); + vst1q_f16(dst_ptr + 3 * dst_step, m[3]); + vst1q_f16(dst_ptr + 4 * dst_step, m[4]); + vst1q_f16(dst_ptr + 5 * dst_step, m[5]); + vst1q_f16(dst_ptr + 6 * dst_step, m[6]); + vst1q_f16(dst_ptr + 7 * dst_step, m[7]); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + LOAD_LINE_DATA_FP16(6); + LOAD_LINE_DATA_FP16(7); + + float16x8_t m0 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 0.5625), vmulq_n_f16(s20, 3.0625)), vmulq_n_f16(s40, 3.5)), s60); + float16x8_t m1 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 0.5625), vmulq_n_f16(s21, 3.0625)), vmulq_n_f16(s41, 3.5)), s61); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + float16x8_t tmp10 = vaddq_f16(vmulq_n_f16(s10, 1.125), vmulq_n_f16(s50, 0.5)); + float16x8_t tmp11 = vaddq_f16(vmulq_n_f16(s11, 1.125), vmulq_n_f16(s51, 0.5)); + float16x8_t tmp20 = vsubq_f16(vmulq_n_f16(s20, 2.25), vmulq_n_f16(s40, 3.25)); + float16x8_t tmp21 = vsubq_f16(vmulq_n_f16(s21, 2.25), vmulq_n_f16(s41, 3.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.5625), s50); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.5625), s51); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.5625), vmulq_n_f16(s40, 2.5)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.5625), vmulq_n_f16(s41, 2.5)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + 
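  // the two 8-lane halves of each 16-wide tile row are stored pack_tile apart
+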
vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.375), vmulq_n_f16(s50, 1.5)); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.375), vmulq_n_f16(s51, 1.5)); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.25), vmulq_n_f16(s40, 1.25)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.25), vmulq_n_f16(s41, 1.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s10, -0.5625), vmulq_n_f16(s30, 3.0625)), vmulq_n_f16(s50, 3.5)), s70); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s11, -0.5625), vmulq_n_f16(s31, 3.0625)), vmulq_n_f16(s51, 3.5)), s71); + vst1q_f16(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 8; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type) { + if (input_unit == 4 && output_unit < 4) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList4[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List4[output_unit]; + } else { + return OutputTransFp16FuncList4[output_unit]; + } + } else if (input_unit == 6 && output_unit < 6) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList6[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List6[output_unit]; + } else { + return OutputTransFp16FuncList6[output_unit]; + } + } else if (input_unit == 8 && output_unit < 8) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList8[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List8[output_unit]; + } else { + return OutputTransFp16FuncList8[output_unit]; + } + } else { + return NULL; + } +} + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = 
vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; 
k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 2] = m[l + 2] > 0 ? 
m[l + 2] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + m[l + 2] = vmin_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] < 6 ? m[l] : 6; + m[l + 2] = m[l + 2] > 0 ? m[l + 2] : 0; + m[l + 2] = m[l + 2] < 6 ? 
m[l + 2] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = 
vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for 
(int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + 
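+  // Winograd output stage (readable off the coefficients in the two l-loops
+  // below, up to the usual generator-polynomial scaling convention): for a
+  // 6x6 input tile X and 3x3 output tile Y this evaluates Y = A^T * X * A
+  // over the interpolation nodes {0, 1, -1, 2, -2, inf}, i.e.
+  //        | 1  1   1  1   1  0 |
+  //  A^T = | 0  1  -1  2  -2  0 |
+  //        | 0  1   1  4   4  1 |
+  // t[] holds A^T * X after the first loop, m[] holds (A^T * X) * A plus the
+  // per-channel bias after the second, and the ReLU variant then clamps each
+  // m[] against `zero`.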
float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, 
int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int 
dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 4] = vminq_f16(six, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 8] = vminq_f16(six, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + 
offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + 
offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 5] = vminq_f16(six, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 10] = vminq_f16(six, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 15] = vminq_f16(six, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + m[l + 20] = vminq_f16(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + 
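+    // Fast path: only a full 8-channel block (r_c == C8NUM) producing a
+    // complete 2x2 output tile can use the vectorized block-store macro; the
+    // else branch below handles edge tiles (r_w or r_h < 2) and ragged
+    // channel tails by scattering one lane at a time through m[...][i].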
Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + 
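+      // tmp1-tmp3 are the symmetric pair sums and tmp4-tmp6 (declared just
+      // below) the antisymmetric pair differences for the node pairs
+      // {+-0.5, +-1, +-1.5}, as the 0.5/1.5 weights indicate; even-order rows
+      // of the output transform consume the sums and odd-order rows the
+      // differences, roughly halving the multiplies versus a direct
+      // matrix-vector evaluation of A^T * t.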
float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + 
offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = 
vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset 
= l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + 
offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 4] = vminq_f16(six, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 8] = vminq_f16(six, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), 
vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + 
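+  // Load64DataFp16 is assumed (by analogy with the Load16/Load36DataFp16
+  // macros used by the smaller tiles) to fill src[0..63] with one
+  // float16x8_t per element of the 8x8 input tile, reading at stride
+  // src_step exactly as the scalar tail paths do.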
float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 5] = vminq_f16(six, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 10] = vminq_f16(six, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 15] = vminq_f16(six, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + m[l + 20] = vminq_f16(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), 
tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), 
vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 64; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] +
bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } 
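/*
 * Note (inferred from the coefficients, not stated in the patch): every
 * OutputTransform8xN in this hunk applies the Winograd output matrix A^T
 * built on the node set {0, +-0.5, +-1, +-1.5, inf}. Output row k takes the
 * sums tmp1..tmp3 when k is even and the differences tmp4..tmp6 when k is
 * odd, each weighted by node^k, which is why the literals are exact powers:
 * 0.25 = 0.5^2, 2.25 = 1.5^2, 3.375 = 1.5^3, 5.0625 = 1.5^4,
 * 7.59375 = 1.5^5, 11.390625 = 1.5^6. The trailing src[7 + offset] or
 * t[7 + offset] term is the node at infinity, which contributes only to the
 * last output row. t[] holds the row pass A^T * d and m[] the column pass
 * (A^T * d) * A plus bias.
 */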
else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + float16x4_t zero = vdup_n_f16(0); + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l + 6] = vmax_f16(zero, m[l + 6]); + m[l + 12] = vmax_f16(zero, m[l + 12]); + m[l + 18] = vmax_f16(zero, m[l + 18]); + m[l + 24] = vmax_f16(zero, m[l + 24]); + m[l + 30] = vmax_f16(zero, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < 
r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 64; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 30] = m[l + 30] > 0 ?
m[l + 30] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 18] = vminq_f16(six, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 24] = vminq_f16(six, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + m[l + 30] = vminq_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + 
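/*
 * Note (annotation on the structure above, not part of the patch): each
 * 8x6 transform walks the channel dimension in up to three tiers. r_c == 8
 * takes the full float16x8_t path; 4 <= r_c < 8 runs the float16x4_t path
 * for the first four channels and sets z = 4 to mark them done; the trailing
 * scalar for (; z < r_c; ++z) loop finishes whatever remains one channel at
 * a time with the same arithmetic restated at float16_t width.
 */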
vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 6] = vmax_f16(zero, m[l + 6]); + m[l + 6] = vmin_f16(six, m[l + 6]); + m[l + 12] = vmax_f16(zero, m[l + 12]); + m[l + 12] = vmin_f16(six, m[l + 12]); + m[l + 18] = vmax_f16(zero, m[l + 18]); + m[l + 18] = vmin_f16(six, m[l + 18]); + m[l + 24] = vmax_f16(zero, m[l + 24]); + m[l + 24] = vmin_f16(six, m[l + 24]); + m[l + 30] = vmax_f16(zero, m[l + 30]); + m[l + 30] = vmin_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, 
m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 64; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] < 6 ? m[l] : 6; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 6] = m[l + 6] < 6 ? m[l + 6] : 6; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 12] = m[l + 12] < 6 ? m[l + 12] : 6; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 18] = m[l + 18] < 6 ? m[l + 18] : 6; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 24] = m[l + 24] < 6 ? m[l + 24] : 6; + m[l + 30] = m[l + 30] > 0 ? m[l + 30] : 0; + m[l + 30] = m[l + 30] < 6 ?
m[l + 30] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } 
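/*
 * Note (annotation, not part of the patch): the Relu and Relu6 variants
 * that follow repeat this plain transform verbatim and only add a clamp
 * after the bias add: relu(x) = max(x, 0) via vmaxq_f16(zero, x), and
 * relu6(x) = min(max(x, 0), 6) via vmaxq_f16 followed by vminq_f16(six, x).
 * Fusing the activation into the output transform avoids a second pass
 * over dst_data.
 */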
else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + 
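/*
 * Note (annotation, not part of the patch): dst_data is laid out so that
 * element (channel i, row j, col k) lives at i + (j * dst_step + k) * out_c,
 * i.e. channels are contiguous with pixel stride out_c. A full 8-channel,
 * full-size tile can therefore be stored with one vst1q_f16 per output
 * pixel, while the partial-tile fallback extracts single lanes as
 * m[k + m_k_offset][i].
 */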
vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 7] = vminq_f16(six, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 
14]); + m[l + 14] = vminq_f16(six, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 21] = vminq_f16(six, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 28] = vminq_f16(six, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 35] = vminq_f16(six, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + m[l + 42] = vminq_f16(six, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +int SelectOutputUnitFp16(const ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_c = conv_param->input_channel_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_c = conv_param->output_channel_; + int unit2 = UP_DIV(out_w * out_h, C16NUM * conv_param->op_parameter_.thread_num_); + int max_out_unit = (int)(sqrtf((float)unit2)); + max_out_unit = max_out_unit < MAX_UNIT_FP16 ? max_out_unit : MAX_UNIT_FP16; + max_out_unit = max_out_unit > MIN_UNIT_FP16 ? 
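/*
 * Note (annotation, not part of the patch): SelectOutputUnitFp16 below
 * scores each candidate output unit i by comparing the direct-convolution
 * multiply count,
 *   common_cost = out_h * out_w * in_c * out_c * kernel_h * kernel_w,
 * against the Winograd work: roughly (2 + out_c) * input_unit^2 * in_c for
 * the input transform plus the element-wise multiply, and
 * (input_unit + i) * i * out_c for the output transform, all times the
 * UP_DIV(out_w, i) * UP_DIV(out_h, i) tile count. A penalty proportional to
 * input_unit^2 / (kernel_h * kernel_w) discounts large tiles (more transform
 * overhead, worse fp16 accuracy), and a speedup ratio below 1.0 falls back
 * to unit 1, i.e. conventional convolution.
 */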
max_out_unit : MIN_UNIT_FP16; + + int unit = 0; + float max_rate = 0.0f; + float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w; + + for (int i = MIN_UNIT_FP16; i <= max_out_unit; ++i) { + int input_unit = i + kernel_w - 1; + if (!GetOutputTransFp16Func(input_unit, i, ActType_No)) { + continue; + } + float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f; + float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) * + UP_DIV(out_w, i) * UP_DIV(out_h, i); + float reduce_rate = common_cost / wino_cost - penalty; + if (reduce_rate > max_rate) { + max_rate = reduce_rate; + unit = i; + } + } + if (max_rate < 1.0f) { + return 1; + } + // If output_unit is 1, then it is conventional convolution + return unit; +} + +void CheckIfUseWinogradFp16(bool *use_winograd, int *output_unit, const ConvParameter *conv_param) { + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + *output_unit = SelectOutputUnitFp16(conv_param); + if (*output_unit > 1) { + *use_winograd = true; + } + } else { + *use_winograd = false; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.h new file mode 100644 index 00000000..40bfdce6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16.h @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_WINOGRAD_UTILS_H_ +#define NNACL_FP16_WINOGRAD_UTILS_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp16/winograd_utils_fp16_macro.h" + +#define MAX_LEN 256 + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); + +typedef void (*InputTransStepFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +typedef void (*InputTransPackFp16Func)(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); + +typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +typedef struct TransFp16FuncList { + InputTransFp16Func in_func_; + InputTransStepFp16Func in_step_func_; + InputTransPackFp16Func in_pack_func_; + OutputTransFp16Func out_func_; +} TransFp16FuncList; + +InputTransFp16Func GetInputTransFp16Func(int input_unit); + +#ifdef ENABLE_ARM64 +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit); + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit); +#endif + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); +#endif + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int 
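/*
 * Note (sketch, not part of the patch): the typedefs above give every
 * transform in this header one signature so the convolution kernel can pick
 * an implementation once and call it per tile. Assuming the usual nnacl
 * ActType values (ActType_No / ActType_Relu / ActType_Relu6 from op_base.h),
 * the expected call site looks like:
 *
 *   OutputTransFp16Func out_func = GetOutputTransFp16Func(8, 6, ActType_Relu6);
 *   if (out_func != NULL) {
 *     out_func(src, dst, bias, src_step, dst_step, out_c, r_w, r_h, r_c);
 *   }
 *
 * GetOutputTransFp16Func returning NULL for an unsupported
 * (input_unit, output_unit) pair is what SelectOutputUnitFp16 relies on to
 * skip invalid candidate units.
 */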
src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluUnitFp16(const 
float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +int SelectOutputUnitFp16(const ConvParameter *conv_param); + +void CheckIfUseWinogradFp16(bool *use_winograd, int *output_unit, const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16_macro.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16_macro.h new file mode 100644 index 00000000..defe39cd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16/winograd_utils_fp16_macro.h @@ -0,0 +1,437 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ +#define NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define Load16DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); + +#define Load16DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); + +#define Load36DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 
* src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); + +#define Load36DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); + +#define Load64DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * 
src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 * src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); \ + src[36] = vld1q_f16(src_data + 36 * src_step); \ + src[37] = vld1q_f16(src_data + 37 * src_step); \ + src[38] = vld1q_f16(src_data + 38 * src_step); \ + src[39] = vld1q_f16(src_data + 39 * src_step); \ + src[40] = vld1q_f16(src_data + 40 * src_step); \ + src[41] = vld1q_f16(src_data + 41 * src_step); \ + src[42] = vld1q_f16(src_data + 42 * src_step); \ + src[43] = vld1q_f16(src_data + 43 * src_step); \ + src[44] = vld1q_f16(src_data + 44 * src_step); \ + src[45] = vld1q_f16(src_data + 45 * src_step); \ + src[46] = vld1q_f16(src_data + 46 * src_step); \ + src[47] = vld1q_f16(src_data + 47 * src_step); \ + src[48] = vld1q_f16(src_data + 48 * src_step); \ + src[49] = vld1q_f16(src_data + 49 * src_step); \ + src[50] = vld1q_f16(src_data + 50 * src_step); \ + src[51] = vld1q_f16(src_data + 51 * src_step); \ + src[52] = vld1q_f16(src_data + 52 * src_step); \ + src[53] = vld1q_f16(src_data + 53 * src_step); \ + src[54] = vld1q_f16(src_data + 54 * src_step); \ + src[55] = vld1q_f16(src_data + 55 * src_step); \ + src[56] = vld1q_f16(src_data + 56 * src_step); \ + src[57] = vld1q_f16(src_data + 57 * src_step); \ + src[58] = vld1q_f16(src_data + 58 * src_step); \ + src[59] = vld1q_f16(src_data + 59 * src_step); \ + src[60] = vld1q_f16(src_data + 60 * src_step); \ + src[61] = vld1q_f16(src_data + 61 * src_step); \ + src[62] = vld1q_f16(src_data + 62 * src_step); \ + src[63] = vld1q_f16(src_data + 63 * src_step); + +#define Load64DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); \ + src[36] = vld1_f16(src_data + 36 * src_step); \ + src[37] = vld1_f16(src_data + 37 * src_step); \ + 
src[38] = vld1_f16(src_data + 38 * src_step); \ + src[39] = vld1_f16(src_data + 39 * src_step); \ + src[40] = vld1_f16(src_data + 40 * src_step); \ + src[41] = vld1_f16(src_data + 41 * src_step); \ + src[42] = vld1_f16(src_data + 42 * src_step); \ + src[43] = vld1_f16(src_data + 43 * src_step); \ + src[44] = vld1_f16(src_data + 44 * src_step); \ + src[45] = vld1_f16(src_data + 45 * src_step); \ + src[46] = vld1_f16(src_data + 46 * src_step); \ + src[47] = vld1_f16(src_data + 47 * src_step); \ + src[48] = vld1_f16(src_data + 48 * src_step); \ + src[49] = vld1_f16(src_data + 49 * src_step); \ + src[50] = vld1_f16(src_data + 50 * src_step); \ + src[51] = vld1_f16(src_data + 51 * src_step); \ + src[52] = vld1_f16(src_data + 52 * src_step); \ + src[53] = vld1_f16(src_data + 53 * src_step); \ + src[54] = vld1_f16(src_data + 54 * src_step); \ + src[55] = vld1_f16(src_data + 55 * src_step); \ + src[56] = vld1_f16(src_data + 56 * src_step); \ + src[57] = vld1_f16(src_data + 57 * src_step); \ + src[58] = vld1_f16(src_data + 58 * src_step); \ + src[59] = vld1_f16(src_data + 59 * src_step); \ + src[60] = vld1_f16(src_data + 60 * src_step); \ + src[61] = vld1_f16(src_data + 61 * src_step); \ + src[62] = vld1_f16(src_data + 62 * src_step); \ + src[63] = vld1_f16(src_data + 63 * src_step); + +#define LOAD_LINE_DATA_FP16(line) \ + float16x8_t s##line##0 = vld1q_f16(src_ptr + line * src_point_stride + 0 * pack_tile); \ + float16x8_t s##line##1 = vld1q_f16(src_ptr + line * src_point_stride + 1 * pack_tile); + +#define TRANSPOSE_16x8 \ + float16x8_t s0 = vld1q_f16(src_ptr + 0 * pack_tile); \ + float16x8_t s2 = vld1q_f16(src_ptr + 1 * pack_tile); \ + float16x8_t s4 = vld1q_f16(src_ptr + 2 * pack_tile); \ + float16x8_t s6 = vld1q_f16(src_ptr + 3 * pack_tile); \ + float16x8_t s8 = vld1q_f16(src_ptr + 4 * pack_tile); \ + float16x8_t s10 = vld1q_f16(src_ptr + 5 * pack_tile); \ + float16x8_t s12 = vld1q_f16(src_ptr + 6 * pack_tile); \ + float16x8_t s14 = vld1q_f16(src_ptr + 7 * pack_tile); \ + float16x8_t s1 = vld1q_f16(src_ptr + 8 * pack_tile); \ + float16x8_t s3 = vld1q_f16(src_ptr + 9 * pack_tile); \ + float16x8_t s5 = vld1q_f16(src_ptr + 10 * pack_tile); \ + float16x8_t s7 = vld1q_f16(src_ptr + 11 * pack_tile); \ + float16x8_t s9 = vld1q_f16(src_ptr + 12 * pack_tile); \ + float16x8_t s11 = vld1q_f16(src_ptr + 13 * pack_tile); \ + float16x8_t s13 = vld1q_f16(src_ptr + 14 * pack_tile); \ + float16x8_t s15 = vld1q_f16(src_ptr + 15 * pack_tile); \ + transpose8(&s0, &s2, &s4, &s6, &s8, &s10, &s12, &s14); \ + transpose8(&s1, &s3, &s5, &s7, &s9, &s11, &s13, &s15); \ + vst1q_f16(src_ptr + 0 * pack_tile, s0); \ + vst1q_f16(src_ptr + 1 * pack_tile, s1); \ + vst1q_f16(src_ptr + 2 * pack_tile, s2); \ + vst1q_f16(src_ptr + 3 * pack_tile, s3); \ + vst1q_f16(src_ptr + 4 * pack_tile, s4); \ + vst1q_f16(src_ptr + 5 * pack_tile, s5); \ + vst1q_f16(src_ptr + 6 * pack_tile, s6); \ + vst1q_f16(src_ptr + 7 * pack_tile, s7); \ + vst1q_f16(src_ptr + 8 * pack_tile, s8); \ + vst1q_f16(src_ptr + 9 * pack_tile, s9); \ + vst1q_f16(src_ptr + 10 * pack_tile, s10); \ + vst1q_f16(src_ptr + 11 * pack_tile, s11); \ + vst1q_f16(src_ptr + 12 * pack_tile, s12); \ + vst1q_f16(src_ptr + 13 * pack_tile, s13); \ + vst1q_f16(src_ptr + 14 * pack_tile, s14); \ + vst1q_f16(src_ptr + 15 * pack_tile, s15); + +#define Store4DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + dst_step * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store4DataC4Fp16 \ + vst1_f16(dst_data, 
m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + dst_step * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store9DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store9DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store16DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store16DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store25DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + 4 * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1q_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * 
dst_step * out_c + out_c, m[11]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +#define Store25DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + 4 * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.c new file mode 100644 index 00000000..f9f35b6d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.c @@ -0,0 +1,151 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
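+ *
+ * fp16 activation backward kernels. Each routine takes the incoming
+ * gradient src0 and the forward activation value src1 (both of length
+ * `length`) and writes the propagated gradient to dst. The NEON paths
+ * build a lane mask with a compare such as vcleq_f16 and select between
+ * the gradient and zero with vbslq_f16, eight half-floats per iteration;
+ * the trailing scalar loops pick up the remainder.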
+ */ + +#include "nnacl_c/fp16_grad/activation_grad_fp16.h" +#include +#include +#ifdef ENABLE_NEON +#include +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +int ReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t zero_v = vdupq_n_f16(0); + for (; i <= length - C8NUM; i += C8NUM) { + float16x8_t src0_v = vld1q_f16(src0 + i); + float16x8_t src1_v = vld1q_f16(src1 + i); + uint16x8_t mask_v = vcleq_f16(src1_v, zero_v); + float16x8_t dst_v = vbslq_f16(mask_v, zero_v, src0_v); + vst1q_f16(dst + i, dst_v); + } +#endif + for (; i < length; i++) { + dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int Relu6Fp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t zero_8 = vdupq_n_f16(0); + float16x8_t six_8 = vdupq_n_f16(6); + for (; i <= length - C8NUM; i += C8NUM) { + float16x8_t src1_8 = vld1q_f16(src1 + i); + float16x8_t src0_8 = vld1q_f16(src0 + i); + uint16x8_t gt_8 = vcgtq_f16(src1_8, zero_8); + uint16x8_t le_8 = vcleq_f16(src1_8, six_8); + uint16x8_t mask_8 = vandq_u16(gt_8, le_8); + float16x8_t dst_8 = vbslq_f16(mask_8, src0_8, zero_8); + vst1q_f16(dst + i, dst_8); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int LReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 one_8 = vdupq_n_f16(1); + for (; i <= length - C8NUM; i += C8NUM) { + MS_FLOAT16X8 src0_8 = MS_LDQ_F16(src0 + i); + MS_FLOAT16X8 src1_8 = MS_LDQ_F16(src1 + i); + MS_STQ_F16(dst + i, vmulq_f16(src0_8, vmulq_f16(src1_8, (one_8 - src1_8)))); + } +#endif + for (; i < length; ++i) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +int SigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t one_8 = vdupq_n_f16(1); + for (; i < length - C8NUM; i += C8NUM) { + float16x8_t src0_8 = vld1q_f16(src0 + i); + float16x8_t src1_8 = vld1q_f16(src1 + i); + float16x8_t dst_8 = vmulq_f16(src0_8, vmulq_f16(src1_8, vsubq_f16(one_8, src1_8))); + vst1q_f16(dst + i, dst_8); + } +#endif + for (; i < length; i++) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +int TanhFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = (float16_t)((1.0f - ((float)src1[i] * (float)src1[i])) * (float)src0[i]); + } + return NNACL_OK; +} + +int HSwishFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + float16_t tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} + +int HSigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + float16_t tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 
0.0f : 1.0f / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} +int EluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + float16x4_t zero_4 = vdup_n_f16(0); + float16x4_t one_4 = vdup_n_f16(1); + float16x4_t alpha_4 = vdup_n_f16(alpha); + for (; i <= length - C4NUM; i += C4NUM) { + float16x4_t src0_4 = vld1_f16(src0 + i); + float16x4_t src1_4 = vld1_f16(src1 + i); + uint16x4_t mask_4 = vcgt_f16(src1_4, zero_4); + float32x4_t tmp; + simd_exp128(vcvt_f32_f16(src1_4), (float *)&tmp); + float16x4_t expm1_4 = vsub_f16(vcvt_f16_f32(tmp), one_4); + float16x4_t dst_4 = vbsl_f16(mask_4, src0_4, vmul_f16(alpha_4, vmul_f16(expm1_4, src0_4))); + vst1_f16(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f ? src0[i] : alpha * expm1(src1[i]) * src0[i]); + } + return NNACL_OK; +} + +int GeluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = src0[i] * ((0.5 * (1.0 + erf(src1[i] / 1.4142135623730951))) + + (src1[i] * exp(-0.5 * src1[i] * src1[i]) / 2.5066282746)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.h new file mode 100644 index 00000000..99e7c463 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/activation_grad_fp16.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
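+ *
+ * A minimal usage sketch (hypothetical caller, not part of this patch):
+ *
+ *   float16_t dy[64], y[64], dx[64];
+ *   ReluFp16Grad(dy, y, 64, dx);     // dx[i] = y[i] > 0 ? dy[i] : 0
+ *   SigmoidFp16Grad(dy, y, 64, dx);  // dx[i] = dy[i] * y[i] * (1 - y[i])
+ *
+ * Callers would typically shard `length` across worker threads and pick
+ * the kernel from the forward op's activation type.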
+ */
+#ifndef NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_
+#define NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/fixed_point.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int Relu6Fp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int LReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha);
+int SigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int TanhFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int HSwishFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int HSigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+int EluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha);
+int GeluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.c
new file mode 100644
index 00000000..40f63a9a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.c
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp16_grad/arithmetic_grad.h"
+#include <stdlib.h>
+#include <string.h>
+#include "nnacl_c/fp32_grad/utils.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+
+void ElementDivNegSquareFp16(const float16_t *nom, const float16_t *denom, float16_t *output, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = -nom[i] / (denom[i] * denom[i]);
+  }
+}
+
+void ElementMulAndDivNegSquareFp16(const float16_t *a, const float16_t *b, const float16_t *denom, float16_t *output,
+                                   int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = -a[i] * b[i] / (denom[i] * denom[i]);
+  }
+}
+
+int ElementAbsGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    out[i] = (in1[i] < 0.f) ? -in2[i] : ((in1[i] > 0.f) ?
in2[i] : 0); + } + return NNACL_OK; +} + +void MaximumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] > input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] >= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float16_t)); // zero output + memset(output1, 0, num_output1 * sizeof(float16_t)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] > input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] >= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +void MinimumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] < input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] <= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float16_t)); // zero output + memset(output1, 0, num_output1 * sizeof(float16_t)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] < input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] <= input0[offset0] ? 
dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +int ElementSqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = 0.5f * in2[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementRsqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = -0.5f * in2[i] * in1[i] * in1[i] * in1[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.h new file mode 100644 index 00000000..77aad4cb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_grad.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_ +#define NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ElementDivNegSquareFp16(const float16_t *nom, const float16_t *denom, float16_t *output, int element_size); +void ElementMulAndDivNegSquareFp16(const float16_t *a, const float16_t *b, const float16_t *denom, float16_t *output, + int element_size); +int ElementAbsGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, int element_size); +void MaximumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims); +void MinimumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims); +int ElementSqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size); +int ElementRsqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.c new file mode 100644 index 00000000..5311f1f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.c @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp16_grad/arithmetic_self_grad.h"
+#include "nnacl_c/errorcode.h"
+
+int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
+  int i = 0;
+#ifdef ENABLE_NEON
+  // d/dx log10(x) = 1 / (x * ln(10)); vrecpeq_f16 is only a reciprocal
+  // estimate, so the vector path trades a little accuracy for speed.
+  float16x8_t log_10 = vdupq_n_f16(log(10));
+  for (; i <= (int)length - C8NUM; i += C8NUM) {
+    float16x8_t src0_8 = vld1q_f16(src0 + i);
+    float16x8_t src1_8 = vld1q_f16(src1 + i);
+    float16x8_t dst_8 = vmulq_f16(src0_8, vrecpeq_f16(vmulq_f16(src1_8, log_10)));
+    vst1q_f16(dst + i, dst_8);
+  }
+#endif
+  for (; i < (int)length; i++) {
+    dst[i] = src0[i] * 1.0f / (src1[i] * log(10));
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.h
new file mode 100644
index 00000000..7e4f6f4a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/arithmetic_self_grad.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP16_GRAD_ARITHMETIC_SELF_GRAD_H_
+#define NNACL_FP16_GRAD_ARITHMETIC_SELF_GRAD_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include <math.h>
+#include "nnacl_c/op_base.h"
+
+typedef struct ArithmeticSelfGradParameterFp16 {
+  OpParameter op_parameter;
+  int type_;
+} ArithmeticSelfGradParameterFp16;
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP16_GRAD_ARITHMETIC_SELF_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.c
new file mode 100644
index 00000000..33ac0542
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.c
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
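+ *
+ * Batch-norm backward. Writing x_hat = (x - mean) * invar and
+ * dx_hat = dy * scale, the input gradient per channel is
+ *   dx = invar * (N * dx_hat - sum(dx_hat) - x_hat * sum(dx_hat * x_hat)) / N,
+ * with the dx_hat reductions accumulated in fp32 (dxhat_sum,
+ * dxhathat_sum). backwardAllFp16 fuses the reduction and application
+ * passes; backwardP1Fp16 accumulates the sums and backwardP2Fp16 applies
+ * them, so callers can walk the spatial dimension in chunks.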
+ */ +#include +#include +#include "nnacl_c/fp16_grad/batch_norm.h" + +void var2InvarFp16(float16_t *save_var, int size, float eps) { + for (int i = 0; i < size; i++) { + save_var[i] = (float16_t)(1.0f / sqrtf((float)save_var[i] + eps)); + } +} + +void backwardAllFp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean, + const float16_t *restrict invar, const float16_t *restrict scale, int size, int ch, + float *restrict dxhat_sum, float *restrict dxhathat_sum, float16_t *restrict dbias, + float16_t *restrict dscale, float16_t *restrict dx) { + NNACL_CHECK_ZERO_RETURN(size); + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // dscale + float16_t x_hat = (in[ix] - mean[c]) * invar[c]; + dscale[c] += (yt[ix] * x_hat); + // dx_1 + float dx_hat = (float)(yt[ix] * scale[c]); + dxhat_sum[c] += dx_hat; + dxhathat_sum[c] += (float)(dx_hat * x_hat); + } + } + float N = (float)size; + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + // dx_2 + int ix = i * ch + c; + float16_t x_hat = (in[ix] - mean[c]) * invar[c]; + float16_t dx_hat = yt[ix] * scale[c]; + dx[ix] = (float16_t)((float)((invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c])) / N); + } + } +} +void backwardP1Fp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean, + const float16_t *restrict invar, const float16_t *restrict scale, int size, int ch, + float *restrict dxhat_sum, float *restrict dxhathat_sum, float16_t *restrict dbias, + float16_t *restrict dscale) { + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // dscale + float x_hat = (float)((in[ix] - mean[c]) * invar[c]); + dscale[c] += (yt[ix] * x_hat); + // dx_1 + float dx_hat = (float)(yt[ix] * scale[c]); + dxhat_sum[c] += dx_hat; + dxhathat_sum[c] += dx_hat * x_hat; + } + } +} + +void backwardP2Fp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean, + const float16_t *restrict invar, const float16_t *restrict scale, int size, int total_size, int ch, + const float *dxhat_sum, const float *dxhathat_sum, float16_t *restrict dx) { + NNACL_CHECK_ZERO_RETURN(total_size); + const float N = (float)total_size; + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + // dx_2 + int ix = i * ch + c; + float x_hat = (float)((in[ix] - mean[c]) * invar[c]); + float dx_hat = (float)(yt[ix] * scale[c]); + dx[ix] = (float16_t)(((float)(invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c])) / N); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.h new file mode 100644 index 00000000..b744aa4d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/batch_norm.h @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_BATCH_NORM_H_ +#define NNACL_FP16_GRAD_BATCH_NORM_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void var2InvarFp16(float16_t *save_var, int size, float eps); +void backwardAllFp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int ch, float *dxhat_sum, float *dxhathat_sum, float16_t *dbias, + float16_t *dscale, float16_t *dx); +void backwardP1Fp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int ch, float *dxhat_sum, float *dxhathat_sum, float16_t *dbias, + float16_t *dscale); +void backwardP2Fp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int total_size, int ch, const float *dxhat_sum, + const float *dxhathat_sum, float16_t *dx); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.c new file mode 100644 index 00000000..d3b07044 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.c @@ -0,0 +1,361 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
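+ *
+ * Depthwise filter gradient: for each kernel tap k_idx = (i_kh, i_kw)
+ * and output channel c,
+ *   dw[c][k_idx] = sum over batch b and output position (oh, ow) of
+ *                  x[b][ih][iw][c] * dy[b][oh][ow][c],
+ * where (ih, iw) is the input pixel that tap touches at that output
+ * position (stride and padding applied); out-of-image positions are
+ * skipped. The Arm variants below block the channel loop by 32/16/8/4
+ * lanes and widen the fp16 products into fp32 accumulators with
+ * MS_VMLAL_F16 to limit rounding error over large reductions.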
+ */ + +#include "nnacl_c/fp16_grad/convolution_grad_filter.h" +#include "nnacl_c/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl_c/errorcode.h" +#ifdef ENABLE_NEON +#include +#endif + +#ifdef ENABLE_NEON + +static int FilterGrad32Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~31); i_c += 32) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + float32x4_t sum_2 = vdupq_n_f32(0.0f); + float32x4_t sum_3 = vdupq_n_f32(0.0f); + float32x4_t sum_4 = vdupq_n_f32(0.0f); + float32x4_t sum_5 = vdupq_n_f32(0.0f); + float32x4_t sum_6 = vdupq_n_f32(0.0f); + float32x4_t sum_7 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + + float16x8_t x_1 = vld1q_f16(x_addr + offset_x + 8); + float16x8_t dy_1 = vld1q_f16(dy_addr + offset_dy + 8); + sum_2 = MS_VMLAL_F16(vget_low_f16(x_1), vget_low_f16(dy_1), sum_2); + sum_3 = MS_VMLAL_F16(vget_high_f16(x_1), vget_high_f16(dy_1), sum_3); + + float16x8_t x_2 = vld1q_f16(x_addr + offset_x + 16); + float16x8_t dy_2 = vld1q_f16(dy_addr + offset_dy + 16); + sum_4 = MS_VMLAL_F16(vget_low_f16(x_2), vget_low_f16(dy_2), sum_4); + sum_5 = MS_VMLAL_F16(vget_high_f16(x_2), vget_high_f16(dy_2), sum_5); + + float16x8_t x_3 = vld1q_f16(x_addr + offset_x + 24); + float16x8_t dy_3 = vld1q_f16(dy_addr + offset_dy + 24); + sum_6 = MS_VMLAL_F16(vget_low_f16(x_3), vget_low_f16(dy_3), sum_6); + sum_7 = MS_VMLAL_F16(vget_high_f16(x_3), vget_high_f16(dy_3), sum_7); + } + } + } + // store into memory + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + 4 + l) * k_spatial + k_idx] = sum_1[l]; + dw[(i_c + 8 + l) * k_spatial + k_idx] = sum_2[l]; + dw[(i_c + 12 + l) * k_spatial + k_idx] = sum_3[l]; + dw[(i_c + 16 + l) * k_spatial + k_idx] = sum_4[l]; + dw[(i_c + 20 + l) * k_spatial + k_idx] = sum_5[l]; + dw[(i_c + 24 + l) * k_spatial + k_idx] = sum_6[l]; + dw[(i_c + 28 + l) * k_spatial + k_idx] = sum_7[l]; + } + } + return i_c; +} + +static int FilterGrad16Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int 
in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~15); i_c += 16) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + float32x4_t sum_2 = vdupq_n_f32(0.0f); + float32x4_t sum_3 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + + float16x8_t x_1 = vld1q_f16(x_addr + offset_x + 8); + float16x8_t dy_1 = vld1q_f16(dy_addr + offset_dy + 8); + sum_2 = MS_VMLAL_F16(vget_low_f16(x_1), vget_low_f16(dy_1), sum_2); + sum_3 = MS_VMLAL_F16(vget_high_f16(x_1), vget_high_f16(dy_1), sum_3); + } + } + } + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + l + 4) * k_spatial + k_idx] = sum_1[l]; + dw[(i_c + l + 8) * k_spatial + k_idx] = sum_2[l]; + dw[(i_c + l + 12) * k_spatial + k_idx] = sum_3[l]; + } + } + return i_c; +} + +static int FilterGrad8Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~7); i_c += 8) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = 
MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + } + } + } + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + 4 + l) * k_spatial + k_idx] = sum_1[l]; + } + } + return i_c; +} + +static int FilterGrad4Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~3); i_c += 4) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x4_t x_0 = vld1_f16(x_addr + offset_x); + float16x4_t dy_0 = vld1_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(x_0, dy_0, sum_0); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_0[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_0[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_0[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_0[3]; + } + return i_c; +} + +static int FilterGradLeftoverArm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int leftover = out_ch - i_c; + if (leftover > 0) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x4_t x_0 = vld1_f16(x_addr + offset_x); + float16x4_t dy_0 = vld1_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(x_0, dy_0, sum_0); + } + } + } + for (int l = 0; l < 
leftover; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + } + } + return out_ch; +} + +#endif + +int ConvDwFilterFp16Grad(const float16_t *x, const float16_t *dy, float16_t *dw, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + + for (int i_k = 0; i_k < count; i_k++) { + int k_idx = start + i_k; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int i_c = 0; +#ifdef ENABLE_NEON + i_c = FilterGrad32Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGradLeftoverArm(x, dy, i_c, k_idx, dw, conv_param); +#endif + for (; i_c < out_ch; i_c++) { + float sum = 0; + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + sum += x_addr[offset_x] * dy_addr[offset_dy]; + } + } + } + dw[i_c * k_spatial + k_idx] = sum; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.h new file mode 100644 index 00000000..ce3a413e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_filter.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
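+ *
+ * ConvDwFilterFp16Grad fills `count` kernel taps starting at tap index
+ * `start`, so a caller can parallelize by handing each worker a disjoint
+ * [start, start + count) slice of the k_h * k_w taps (a reading of the
+ * interface, mirroring the fp32 counterpart).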
+ */ + +#ifndef NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_ +#define NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_ + +#include +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwFilterFp16Grad(const float16_t *x, const float16_t *dy, float16_t *dw, int start, int count, + const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.c new file mode 100644 index 00000000..4ff6af5e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.c @@ -0,0 +1,332 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16_grad/convolution_grad_input.h" +#include "nnacl_c/errorcode.h" +#ifdef ENABLE_ARM +#include +#endif + +static int ConvDwInputGrad16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C16NUM); j += C16NUM) { + float16_t *c = dx + j; + const float16_t *mat_b[C16NUM]; + for (int j_i = 0; j_i < C16NUM; j_i++) { + mat_b[j_i] = w + (j + j_i) * k_spatial; + } + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x8_t mat_b0 = {mat_b[0][k], mat_b[1][k], mat_b[2][k], mat_b[3][k], + mat_b[4][k], mat_b[5][k], mat_b[6][k], mat_b[7][k]}; + float16x8_t mat_b1 = {mat_b[8][k], mat_b[9][k], mat_b[10][k], mat_b[11][k], + mat_b[12][k], mat_b[13][k], mat_b[14][k], mat_b[15][k]}; +#else + float16x4_t mat_b00; + float16x4_t mat_b01; + float16x4_t mat_b10; + float16x4_t mat_b11; + asm volatile( + "vld1.16 
%0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b00), "=w"(mat_b01) + : "r"(mat_b[0] + k), "r"(mat_b[1] + k), "r"(mat_b[2] + k), "r"(mat_b[3] + k), "r"(mat_b[4] + k), + "r"(mat_b[5] + k), "r"(mat_b[6] + k), "r"(mat_b[7] + k) + :); + asm volatile( + "vld1.16 %0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b10), "=w"(mat_b11) + : "r"(mat_b[8] + k), "r"(mat_b[9] + k), "r"(mat_b[10] + k), "r"(mat_b[11] + k), "r"(mat_b[12] + k), + "r"(mat_b[13] + k), "r"(mat_b[14] + k), "r"(mat_b[15] + k) + :); + float16x8_t mat_b0 = vcombine_f16(mat_b00, mat_b01); + float16x8_t mat_b1 = vcombine_f16(mat_b10, mat_b11); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x8_t mat_c0 = vld1q_f16(c + dx_offset); + float16x8_t mat_a0 = vld1q_f16(a + dy_offset); + mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0); + vst1q_f16(c + dx_offset, mat_c0); + + float16x8_t mat_c1 = vld1q_f16(c + dx_offset + 8); + float16x8_t mat_a1 = vld1q_f16(a + dy_offset + 8); + mat_c1 = vfmaq_f16(mat_c1, mat_b1, mat_a1); + vst1q_f16(c + dx_offset + 8, mat_c1); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + for (int j_i = 0; j_i < C16NUM; j_i++) { + c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k]; + } + } +#endif + } + } + } + } + return j; +} + +static int ConvDwInputGrad8(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C8NUM); j += C8NUM) { + float16_t *c = dx + j; + const float16_t *mat_b[C8NUM]; + for (int j_i = 0; j_i < C8NUM; j_i++) { + mat_b[j_i] = w + (j + j_i) * k_spatial; + } + + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x8_t mat_b0 = {mat_b[0][k], mat_b[1][k], mat_b[2][k], mat_b[3][k], + mat_b[4][k], mat_b[5][k], mat_b[6][k], mat_b[7][k]}; +#else + float16x4_t mat_b00; + float16x4_t mat_b01; + asm volatile( + "vld1.16 %0[0], [%2]\n" + 
"vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b00), "=w"(mat_b01) + : "r"(mat_b[0] + k), "r"(mat_b[1] + k), "r"(mat_b[2] + k), "r"(mat_b[3] + k), "r"(mat_b[4] + k), + "r"(mat_b[5] + k), "r"(mat_b[6] + k), "r"(mat_b[7] + k) + :); + float16x8_t mat_b0 = vcombine_f16(mat_b00, mat_b01); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x8_t mat_c0 = vld1q_f16(c + dx_offset); + float16x8_t mat_a0 = vld1q_f16(a + dy_offset); + mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0); + vst1q_f16(c + dx_offset, mat_c0); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + for (int j_i = 0; j_i < C8NUM; j_i++) { + c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k]; + } + } +#endif + } + } + } + } + return j; +} + +static int ConvDwInputGrad4(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C4NUM); j += C4NUM) { + float16_t *c = dx + j; + const float16_t *mat_b_0 = w + (j + 0) * k_spatial; + const float16_t *mat_b_1 = w + (j + 1) * k_spatial; + const float16_t *mat_b_2 = w + (j + 2) * k_spatial; + const float16_t *mat_b_3 = w + (j + 3) * k_spatial; + + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x4_t mat_b = {mat_b_0[k], mat_b_1[k], mat_b_2[k], mat_b_3[k]}; +#else + float16x4_t mat_b; + asm volatile( + "vld1.16 %0[0], [%1]\n" + "vld1.16 %0[1], [%2]\n" + "vld1.16 %0[2], [%3]\n" + "vld1.16 %0[3], [%4]\n" + : "=w"(mat_b) + : "r"(mat_b_0 + k), "r"(mat_b_1 + k), "r"(mat_b_2 + k), "r"(mat_b_3 + k) + :); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x4_t mat_c = vld1_f16(c + dx_offset); + float16x4_t mat_a = vld1_f16(a + dy_offset); + mat_c = vfma_f16(mat_c, mat_b, mat_a); + vst1_f16(c + dx_offset, mat_c); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + c[dx_offset + 0] += a[dy_offset + 0] * mat_b_0[k]; + c[dx_offset + 1] += a[dy_offset + 1] * mat_b_1[k]; + 
c[dx_offset + 2] += a[dy_offset + 2] * mat_b_2[k]; + c[dx_offset + 3] += a[dy_offset + 3] * mat_b_3[k]; + } +#endif + } + } + } + } + return j; +} + +int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int end = start + count; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + j = ConvDwInputGrad16(dy, w, dx, j, end, conv_param); + j = ConvDwInputGrad8(dy, w, dx, j, end, conv_param); + j = ConvDwInputGrad4(dy, w, dx, j, end, conv_param); + for (; j < end; j++) { + float16_t *c = dx + j; + const float16_t *b = w + j * k_spatial; + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = si / out_w; + int output_col = si % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; + for (int bi = 0; bi < batch; bi++) { + c[bi * in_size + offset + 0] += a[0 + bi * out_size] * b[k]; + } + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.h new file mode 100644 index 00000000..5e7c2485 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/convolution_grad_input.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_GRAD_CONVOLUTION_GRAD_INPUT_H_ +#define NNACL_FP16_GRAD_CONVOLUTION_GRAD_INPUT_H_ + +#include <arm_neon.h> +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count, + const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_GRAD_CONVOLUTION_GRAD_INPUT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.c new file mode 100644 index 00000000..48695c53 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.c @@ -0,0 +1,24 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16_grad/dropout_grad.h" + +void DropoutFp16Grad(const float16_t *yt_ptr, const float16_t *mask, float16_t *output_ptr, int length, + float16_t scale) { + for (int i = 0; i < length; i++) { + output_ptr[i] = yt_ptr[i] * mask[i] * scale; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.h new file mode 100644 index 00000000..629f1996 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/dropout_grad.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_DROPOUT_GRAD_H_ +#define NNACL_FP16_GRAD_DROPOUT_GRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void DropoutFp16Grad(const float16_t *yt_ptr, const float16_t *mask, float16_t *output_ptr, int length, + float16_t scale); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_GRAD_DROPOUT_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c new file mode 100644 index 00000000..408ef280 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.c @@ -0,0 +1,385 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp16_grad/gemm_fp16.h" +#include <string.h> +#ifdef __ARM_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/fp16/matmul_fp16.h" +#include "nnacl_c/fp16/pack_fp16.h" + +#ifdef ENABLE_ARM64 +static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { + size_t stride = col * 2; + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[stride]\n" + "ld1 {v9.8h}, [x10], %[stride]\n" + "ld1 {v10.8h}, [x10], %[stride]\n" + "ld1 {v11.8h}, [x10], %[stride]\n" + "ld1 {v12.8h}, [x10], %[stride]\n" + "ld1 {v13.8h}, [x10], %[stride]\n" + "ld1 {v14.8h}, [x10], %[stride]\n" + "ld1 {v15.8h}, [x10], %[stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + : + : [dst_c] "r"(dst_ptr),
[src_c] "r"(src_ptr), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +void AddMatrixFp16(const float16_t *restrict v1, float16_t *restrict v2, float16_t beta, int row, int col, int stride) { + const float16_t *src_ptr = v1; + float16_t *dst_ptr = v2; +#ifdef ENABLE_NEON + float16x8_t beta_0 = vdupq_n_f16(beta); +#endif + for (int r = 0; r < row; r++) { + int c = 0; +#ifdef ENABLE_NEON + for (; c <= (col - C8NUM); c += C8NUM) { + float16x8_t dst_0 = vld1q_f16(dst_ptr + c); + float16x8_t src_0 = vld1q_f16(src_ptr + c); + float16x8_t sum_0 = vfmaq_f16(dst_0, beta_0, src_0); + vst1q_f16(dst_ptr + c, sum_0); + } +#endif + for (; c < col; c++) { + dst_ptr[c] += beta * src_ptr[c]; + } + src_ptr += stride; + dst_ptr += stride; + } +} + +int MatSizeFp16(int row, int col, int round) { + int res = UP_ROUND(row, round) * col; + return res; +} + +int MatSizeTotalFp16(int row, int col, int deep, int stride) { +#ifdef ENABLE_ARM64 + const int num = C16NUM; +#else + const int num = C12NUM; +#endif + int res = MatSizeFp16(row, deep, num) + MatSizeFp16(col, deep, C8NUM); + if (stride > 0) res += row * stride; + return res; +} + +#ifdef ENABLE_ARM64 +static void RowMajor2Col16MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + size_t row_up_16 = UP_ROUND(row, C16NUM); + size_t row16 = row / C16NUM * C16NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src; + float16_t *dst_r = dst; + size_t ri = 0; + // find 16 block unit + for (; ri < row16; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_ARM64 + Row2Col16Block16(src_c, dst_c, stride); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * stride + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * stride]; + } + } + src_r += C16NUM * stride; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += stride; + dst_r += 1; + } + for (; ri < row_up_16; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + return; +} +#endif + +void RowMajor2Row16MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = src[r * stride + c]; + } + } +} + +void RowMajor2Col12MajorStrideFp16(const float16_t *src, float16_t *dst, size_t row, size_t col, int stride) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src; + float16_t *dst_r = dst; + size_t ri = 0; + // transpose 12x8 + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM82_A32 + Transpose12x8A32Fp16(src_c, dst_c, stride * sizeof(float16_t), 24); +#else + for (int tr = 
0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * stride + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * stride]; + } + } + src_r += C12NUM * stride; + dst_r += C12NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += stride; + dst_r += 1; + } + for (; ri < row_up_12; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Row12MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div12 = c / C12NUM; + int c_mod12 = c % C12NUM; + dst[c_div12 * C12NUM * row + r * C12NUM + c_mod12] = src[r * stride + c]; + } + } +} + +static void RowMajor2Col8MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div8 = r / C8NUM; + int r_mod8 = r % C8NUM; + dst[r_div8 * C8NUM * col + c * C8NUM + r_mod8] = src[r * stride + c]; + } + } +} + +static void RowMajor2Row8MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + const float16_t *src_ptr = src + r * stride; + int c = 0; + for (; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst[cd8 * C8NUM * row + r * C8NUM + cm8] = src_ptr[c]; + } + for (; c < UP_ROUND(col, C8NUM); c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst[cd8 * C8NUM * row + r * C8NUM + cm8] = 0; + } + } + return; +} + +static void RowMajor2ColXMajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorStrideFp16(src, dst, row, col, stride); +#else + RowMajor2Col12MajorStrideFp16(src, dst, row, col, stride); +#endif +} + +static void RowMajor2RowXMajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { +#ifdef ENABLE_ARM64 + RowMajor2Row16MajorStrideFp16(src, dst, row, col, stride); +#else + RowMajor2Row12MajorStrideFp16(src, dst, row, col, stride); +#endif +} + +void GemmMatmulFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, float16_t *workspace) { + GemmCbFp16 gcb; + gcb.atype = ActType_No; + gcb.ca = 0; + gcb.cb = 0; + gcb.bias = NULL; + GemmMatmulPlusFp16(ta, tb, M, N, K, alpha, mat_a, lda, mat_b, ldb, beta, mat_c, ldc, workspace, &gcb); +} + +void GemmMatmulPlusFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, + float16_t *workspace, GemmCbFp16 *gcb) { +#ifdef ENABLE_ARM64 + const int num = C16NUM; +#else + const int num = C12NUM; +#endif + float16_t *output = mat_c; + float16_t *fworkspace = workspace; + int incremental = (beta < 0.f) || (beta > 0.f); + float16_t *mat_a_input = (float16_t *)mat_a; + float16_t *mat_b_input = (float16_t *)mat_b; + + if (!gcb->ca) { + mat_a_input = fworkspace; + fworkspace += MatSizeFp16(M, K, num); + if (ta) { + RowMajor2RowXMajorStrideFp16(mat_a, mat_a_input, K, M, lda); + } else { + RowMajor2ColXMajorStrideFp16(mat_a, mat_a_input, M, K, lda); + } + } + if (!gcb->cb) { + mat_b_input = 
fworkspace; + fworkspace += MatSizeFp16(N, K, C8NUM); + if (tb) { + RowMajor2Col8MajorStrideFp16(mat_b, mat_b_input, N, K, ldb); + } else { + RowMajor2Row8MajorStrideFp16(mat_b, mat_b_input, K, N, ldb); + } + } + if (incremental) output = fworkspace; + MatMulFp16(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); + if (incremental) AddMatrixFp16(output, mat_c, beta, M, N, ldc); + gcb->mat_a = mat_a_input; + gcb->mat_b = mat_b_input; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h new file mode 100644 index 00000000..3f964c27 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/gemm_fp16.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_GEMM_FP16_H_ +#define NNACL_FP16_GRAD_GEMM_FP16_H_ + +#include <arm_neon.h> +#include "nnacl_c/op_base.h" +#ifdef __cplusplus +extern "C" { +#endif +typedef struct { + int ca; + int cb; + ActType atype; + float16_t *bias; + float16_t *mat_a; + float16_t *mat_b; +} GemmCbFp16; + +void GemmMatmulFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, float16_t *workspace); +void GemmMatmulPlusFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, + float16_t *workspace, GemmCbFp16 *gcb); +int MatSizeFp16(int row, int col, int round); +int MatSizeTotalFp16(int row, int col, int deep, int stride); +void AddMatrixFp16(const float16_t *v1, float16_t *v2, float16_t beta, int row, int col, int stride); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_GRAD_GEMM_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.c new file mode 100644 index 00000000..c7a819c4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "nnacl_c/fp16_grad/layernorm_grad.h" +#include <stddef.h> +#include <math.h> + +void LayerNormFp16Grad(const float16_t *x, const float16_t *dy, const float16_t *var, const float16_t *mean, + const float16_t *gamma, int param_num, int param_size, int block_num, int block_size, + float16_t *dx, float16_t *dg, float16_t *db) { + // var here already holds 1/sqrt(var), i.e. the reciprocal of the standard deviation + NNACL_CHECK_ZERO_RETURN(block_size); + const float16_t *var_sqrt_rev = var; + for (size_t i = 0; i < param_num; ++i) { + float dgamma = 0.0f; + float dbeta = 0.0f; + for (size_t j = i; j < param_size * param_num; j += param_num) { + int norm_shift = (int)(j / block_size); + dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]); + dbeta += dy[j]; + } + dg[i] = (float16_t)dgamma; + db[i] = (float16_t)dbeta; + } + for (size_t i = 0; i < block_num; ++i) { + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float16_t dxm = x[j] - mean[norm_shift]; + float16_t dyg = dy[j] * gamma[param_shift]; + sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift]; + sum2 += dyg; + sum3 += -2.0f * dxm; + } + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float16_t var_sqrt = var_sqrt_rev[norm_shift]; + float dx1 = dy[j] * gamma[param_shift] * var_sqrt; + float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]); + float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size); + dx[j] = (float16_t)(dx1 + dx2 + dx3); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.h new file mode 100644 index 00000000..1eca819d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/layernorm_grad.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ +#define NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void LayerNormFp16Grad(const float16_t *x, const float16_t *dy, const float16_t *var, const float16_t *mean, + const float16_t *gamma, int param_num, int param_size, int block_num, int block_size, + float16_t *dx, float16_t *dg, float16_t *db); + +#ifdef __cplusplus +} +#endif +#endif  // NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.c new file mode 100644 index 00000000..a2c5b47b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.c @@ -0,0 +1,201 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <string.h> +#include "nnacl_c/fp16_grad/pack_fp16_ext.h" + +void RollingIm2ColPackDwUnitFp16(const float16_t *in_data, const ConvParameter *conv_param, float16_t *data_col_orig, + int real_cal_num, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_; + const int stride = kernel_h * kernel_w; + + int kernel_row, kernel_col; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + float16_t *data_col = data_col_orig + i * channels * stride; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * channels; + for (int c = 0; c < channels; c++) { + data_col[c * stride] = in_data[offset + c]; + } + data_col++; + } else { + for (int c = 0; c < channels; c++) { + data_col[c * stride] = 0; + } + data_col++; + } + } + } + } +} + +void RollingIm2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h =
conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + *packed_input = input_data[offset]; + packed_input++; + } else { + *packed_input = 0; + packed_input++; + } + } + } + } + } else { + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(packed_input, input_data + offset, sizeof(float16_t) * channels); + packed_input += channels; + } else { + memset(packed_input, 0, sizeof(float16_t) * channels); + packed_input += channels; + } + } + } + } + } +} + +void RollingCol2ImPackUnitFp16(const float16_t *data_col, float16_t *data_im, const ConvParameter *conv_param, + int real_cal_num, int block_index) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int r = 0; r < real_cal_num; r++) { + int output_col = (block_index + r) % output_w; + int output_row = (block_index + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int 
offset = (input_row * in_width + input_col) * tot_channels; + float16_t *data_im_ptr = data_im + offset; + *data_im_ptr += *data_col; + } + data_col++; + } + } + } + } else { + for (int r = 0; r < real_cal_num; r++) { + int output_col = (block_index + r) % output_w; + int output_row = (block_index + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float16_t *data_im_ptr = &data_im[offset]; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.h new file mode 100644 index 00000000..0d6d6841 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pack_fp16_ext.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_PACK_FP16_EXT_H_ +#define NNACL_FP16_GRAD_PACK_FP16_EXT_H_ + +#include <arm_neon.h> +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void RollingIm2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); +void RollingIm2ColPackDwUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); +void RollingCol2ImPackUnitFp16(const float16_t *data_col, float16_t *data_im, const ConvParameter *conv_param, + int real_cal_num, int block_index); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_GRAD_PACK_FP16_EXT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.c new file mode 100644 index 00000000..4ab7f96b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.c @@ -0,0 +1,192 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <stddef.h> +#include <math.h> +#include <float.h> +#include "nnacl_c/fp16_grad/pooling_grad.h" +#include "nnacl_c/op_base.h" + +void AvgPoolingFp16Grad(const float16_t *input_ptr, float16_t *output_ptr, int count, PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + + const float16_t kk = 1.0f / (float16_t)(win_h * win_w); +#ifdef ENABLE_NEON + const float16x4_t factor = vdup_n_f16(kk); +#endif + for (int ib = 0; ib < count; ib++) { + float16_t *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float16_t *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; + // iterate over yt + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < channel - C4NUM; ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_NEON + float16x4_t in = vld1_f16(inPtr + idx); + float16x4_t delta = vmul_f16(in, factor); +#else + float16_t delta[C4NUM] = {inPtr[idx], inPtr[idx + C1NUM], inPtr[idx + C2NUM], inPtr[idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) delta[i] *= kk; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; +#ifdef ENABLE_NEON + float16_t *out_vec = out + (xw + in_w * xh) * channel + ic; + float16x4_t outr = vld1_f16(out + (xw + in_w * xh) * channel + ic); + float16x4_t outs = vadd_f16(outr, delta); + vst1_f16(out_vec, outs); +#else + + for (int i = 0; i < C4NUM; i++) { + out[(xw + in_w * xh) * channel + ic + i] += ((float16_t *)&delta)[i]; + } +#endif + } + } + } + for (; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; + float16_t delta = inPtr[idx] * kk; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + out[(xw + in_w * xh) * channel + ic] += delta; + } + } + } + } + } + } +} + +#ifdef ENABLE_NEON +static int32x4_t MaxIndex(float16x4_t in, float16x4_t *max, uint32x4_t index, uint32x4_t prev_index) { + uint16x4_t res = vcgt_f16(in, *max); + int16x4_t tmp = vreinterpret_s16_u16(res); + uint32x4_t res_tmp = vreinterpretq_u32_s32(vmovl_s16(tmp)); + int32x4_t m_index = vbslq_s32(res_tmp, (int32x4_t)index, (int32x4_t)prev_index); + *max = vbsl_f16(res, in, *max); + return m_index; +} +#endif + +void MaxPoolingFp16Grad(const float16_t *input_ptr, const float16_t *dy_ptr, float16_t *output_ptr, int output_batch, + PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h =
pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + + for (int ib = 0; ib < output_batch; ib++) { + float16_t *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float16_t *inPtr = &input_ptr[(ib * in_h * in_w * channel)]; + const float16_t *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)]; + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < (channel & ~3); ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_NEON + uint32x4_t max_idx = vdupq_n_u32(0); + float16x4_t max_val = vdup_n_f16(-FLT16_MAX); + float16x4_t delta = vld1_f16(dyPtr + idx); +#else + float16_t delta[C4NUM] = {dyPtr[idx], dyPtr[idx + C1NUM], dyPtr[idx + C2NUM], dyPtr[idx + C3NUM]}; + float16_t max_val[C4NUM] = {-FLT16_MAX, -FLT16_MAX, -FLT16_MAX, -FLT16_MAX}; + unsigned int max_idx[C4NUM] = {0}; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; +#ifdef ENABLE_NEON + uint32x4_t index = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3}; + float16x4_t in = vld1_f16(inPtr + val_idx); + max_idx = (uint32x4_t)MaxIndex(in, &max_val, index, max_idx); +#else + float16_t val[C4NUM] = {inPtr[val_idx], inPtr[val_idx + C1NUM], inPtr[val_idx + C2NUM], + inPtr[val_idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) { + if (val[i] > max_val[i]) { + max_val[i] = val[i]; + max_idx[i] = val_idx + i; + } + } +#endif + } + } + for (int i = 0; i < C4NUM; i++) { + out[((int *)&max_idx)[i]] += ((float16_t *)&delta)[i]; + } + } + for (; ic < channel; ic++) { + float16_t max_val = -FLT16_MAX; + int max_idx = 0; + int idx = (yw + yh * output_w) * channel + ic; + float16_t delta = dyPtr[idx]; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; + float16_t val = inPtr[val_idx]; + if (val > max_val) { + max_val = val; + max_idx = val_idx; + } + } + } + out[max_idx] += delta; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.h new file mode 100644 index 00000000..5c2bdee2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/pooling_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_FP16_GRAD_POOLING_GRAD_H_ +#define NNACL_FP16_GRAD_POOLING_GRAD_H_ + +#include "nnacl_c/fp16/pooling_fp16.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +void AvgPoolingFp16Grad(const float16_t *input_ptr, float16_t *output_ptr, int count, PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args); +void MaxPoolingFp16Grad(const float16_t *input_ptr, const float16_t *dy_ptr, float16_t *output_ptr, int output_batch, + PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args); +#ifdef __cplusplus +} +#endif + +#endif  // NNACL_FP16_GRAD_POOLING_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.c new file mode 100644 index 00000000..6806293a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.c @@ -0,0 +1,146 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16_grad/resize_grad.h" +#include <math.h> +#include "nnacl_c/infer/common_infer.h" + +int ResizeNearestNeighborFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param) { + bool align_corners = param->align_corners_; + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t in_y = i / param->in_width_; + size_t in_x = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + size_t out_y = MSMIN( + (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), + param->out_height_ - 1); + size_t out_x = MSMIN( + (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), + param->out_width_ - 1); + size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; + size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; + out_addr[out_offset] += in_addr[in_offset]; + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + for (int32_t b = 0; b < batch_size; ++b) { + for (int32_t c = 0; c < channel; ++c) { + for (size_t h = 0; h < param->in_height_; ++h) { + size_t out_y = + MSMIN((align_corners) ? (size_t)roundf(h * param->height_scale_) : (size_t)floorf(h * param->height_scale_), + param->out_height_ - 1); + for (size_t w = 0; w < param->in_width_; ++w) { + size_t out_x = + MSMIN((align_corners) ?
(size_t)roundf(w * param->width_scale_) : (size_t)floorf(w * param->width_scale_), + param->out_width_ - 1); + out_addr[out_y * param->out_width_ + out_x] += in_addr[h * param->in_width_ + w]; + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} + +int ResizeBiLinearFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param) { + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t h = i / param->in_width_; + size_t w = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + float16_t in_y = (float16_t)h * param->height_scale_; + size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); + size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); + float16_t y_lerp = in_y - floorf(in_y); + const float16_t inverse_y_lerp = 1.0 - y_lerp; + + float16_t in_x = (float16_t)w * param->width_scale_; + size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); + size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); + float16_t x_lerp = in_x - floorf(in_x); + const float16_t inverse_x_lerp = 1.0 - x_lerp; + + size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; + size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + size_t out_offset_bottom_y_left_x = + bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_bottom_y_right_x = + bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + + out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float16_t)(inverse_y_lerp * inverse_x_lerp); + out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float16_t)(inverse_y_lerp * x_lerp); + out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float16_t)(y_lerp * inverse_x_lerp); + out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float16_t)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + size_t in_height = param->in_height_; + size_t in_width = param->in_width_; + size_t out_height = param->out_height_; + size_t out_width = param->out_width_; + + for (size_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < channel; ++c) { + for (size_t h = 0; h < in_height; ++h) { + const float16_t in_y = (float16_t)(h)*param->height_scale_; + const size_t top_y_index = MSMAX((size_t)floorf(in_y), 0); + const size_t bottom_y_index = MSMIN((size_t)ceilf(in_y), out_height - 1); + const float16_t y_lerp = in_y - floorf(in_y); + const float16_t inverse_y_lerp = 1.0 - y_lerp; + for (size_t w = 0; w < in_width; ++w) { + const float16_t in_x = (float16_t)(w)*param->width_scale_; + const size_t left_x_index = MSMAX((size_t)floorf(in_x), 0); + const size_t right_x_index = MSMIN((size_t)ceilf(in_x), out_width - 1); + const float16_t x_lerp = in_x - floorf(in_x); + const float16_t inverse_x_lerp = 1.0 - x_lerp; + out_addr[top_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float16_t)(inverse_y_lerp * inverse_x_lerp); + 
out_addr[top_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float16_t)(inverse_y_lerp * x_lerp); + out_addr[bottom_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float16_t)(y_lerp * inverse_x_lerp); + out_addr[bottom_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float16_t)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.h new file mode 100644 index 00000000..e25fd0b4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/resize_grad.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_RESIZE_GRAD_H_ +#define NNACL_FP16_GRAD_RESIZE_GRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct ResizeFp16GradParameter { + OpParameter op_parameter_; + bool align_corners_; + int method; + size_t in_height_; + size_t in_width_; + size_t out_height_; + size_t out_width_; + float16_t height_scale_; + float16_t width_scale_; +} ResizeFp16GradParameter; + +int ResizeNearestNeighborFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param); +int ResizeBiLinearFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.c new file mode 100644 index 00000000..77582468 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.c @@ -0,0 +1,67 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp16_grad/strided_slice_grad.h" +#include "nnacl_c/errorcode.h" + +static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) { + size_t res = 1; + for (size_t j = 0; j < size; j++) { + res *= shape[(i + 1) + j]; + } + NNACL_CHECK_ZERO_RETURN_ERR(res); + NNACL_CHECK_ZERO_RETURN_ERR(shape[i]); + return (pos / res % shape[i]); +} + +int DoStridedSliceFp16Grad(const float16_t *inputs, float16_t *output, const int *dx_shape, + StridedSliceParameter *param) { + if (inputs == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_7D) { + return NNACL_PARAM_INVALID; + } + + size_t size = 1; + int *s = param->strides_; + int *b = param->begins_; + for (int i = 0; i < DIMENSION_7D; i++) { + size *= param->in_shape_[i]; + } + + for (size_t pos = 0; pos < size; pos++) { + size_t i = CalcIndex(param->in_shape_, C6NUM, C0NUM, pos); + size_t j = CalcIndex(param->in_shape_, C5NUM, C1NUM, pos); + size_t k = CalcIndex(param->in_shape_, C4NUM, C2NUM, pos); + size_t l = CalcIndex(param->in_shape_, C3NUM, C3NUM, pos); + size_t m = CalcIndex(param->in_shape_, C2NUM, C4NUM, pos); + size_t n = CalcIndex(param->in_shape_, C1NUM, C5NUM, pos); + size_t o = CalcIndex(param->in_shape_, C0NUM, C6NUM, pos); + + size_t input_idx = + (i * s[C0NUM] + b[C0NUM]) * dx_shape[C1NUM] * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * + dx_shape[C5NUM] * dx_shape[C6NUM] + + (j * s[C1NUM] + b[C1NUM]) * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * + dx_shape[C6NUM] + + (k * s[C2NUM] + b[C2NUM]) * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] + + (l * s[C3NUM] + b[C3NUM]) * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] + + (m * s[C4NUM] + b[C4NUM]) * dx_shape[C5NUM] * dx_shape[C6NUM] + (n * s[C5NUM] + b[C5NUM]) * dx_shape[C6NUM] + + (o * s[C6NUM] + b[C6NUM]); + output[input_idx] = inputs[pos]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.h new file mode 100644 index 00000000..6a79e8e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/strided_slice_grad.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ +#define NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoStridedSliceFp16Grad(const float16_t *inputs, float16_t *output, const int *dx_shape, + StridedSliceParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.c new file mode 100644 index 00000000..8f794d42 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp16_grad/unsorted_segment_sum.h" +#include "nnacl_c/errorcode.h" + +int UnsortedSegmentSumFp16(const float16_t *input, int unit_num, int input_dim1, const int *indices, float16_t *output, + int output_dim0, int output_dim1) { + NNACL_CHECK_ZERO_RETURN_ERR(input_dim1); + for (int i = 0; i < unit_num; ++i) { + int j = i / input_dim1; + int k = i % input_dim1; + + int index = indices[j]; + if (index < 0 || index >= output_dim0) { + continue; + } + int output_index = index * output_dim1 + k; + output[output_index] += input[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.h new file mode 100644 index 00000000..85a54ab2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp16_grad/unsorted_segment_sum.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ +#define NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnsortedSegmentSumFp16(const float16_t *input, int unit_num, int input_dim1, const int *indices, float16_t *output, + int output_dim0, int output_dim1); +#ifdef __cplusplus +} +#endif +#endif  // NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.c new file mode 100644 index 00000000..f284b75a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.c @@ -0,0 +1,292 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <math.h> +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/activation_fp32_simd.h" + +int Fp32Relu(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Relu, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return NNACL_OK; +} + +int Int32Relu(const int32_t *src, int length, int32_t *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Int32Relu, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return NNACL_OK; +} + +int Fp32Relu6(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Relu6, i, src, length, dst); + + for (; i < length; ++i) { + if (src[i] < 0) { + dst[i] = 0; + } else { + dst[i] = src[i] > 6.0f ? 6.0f : src[i]; // relu 6.0 + } + } + return NNACL_OK; +} + +int Fp32Clip(const float *src, int length, float *dst, float min, float max) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Clip, i, src, length, dst, min, max); + + for (; i < length; ++i) { + if (src[i] < min) { + dst[i] = min; + } else { + dst[i] = src[i] > max ? max : src[i]; + } + } + return NNACL_OK; +} + +int Int32Clip(const int32_t *src, int length, int32_t *dst, int min, int max) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Int32Clip, i, src, length, dst, min, max); + + for (; i < length; ++i) { + if (src[i] < min) { + dst[i] = min; + } else { + dst[i] = src[i] > max ? max : src[i]; + } + } + return NNACL_OK; +} + +int LRelu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(LRelu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ?
src[i] : (src[i] * alpha); + } + return NNACL_OK; +} + +int Sigmoid(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Sigmoid, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(-src[i], dst + i); + dst[i] = 1.0f / (1.0f + dst[i]); + } + return NNACL_OK; +} + +float TanhOpt(float src) { + if (src > 5.0) { // src > 5.0, tanh(src) = 1.0f + return 1.0f; + } else if (src < -5.0) { // src < -5.0, tanh(src) = -1.0f + return -1.0f; + } else { + float square = src * src; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * src; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + return a / b; + } +} + +int Tanh(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Tanh, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = TanhOpt(src[i]); + } + return NNACL_OK; +} + +int Swish(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Swish, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(-src[i], dst + i); + dst[i] = src[i] / (1.0f + dst[i]); + } + return NNACL_OK; +} + +int HSwish(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(HSwish, i, src, length, dst); + + for (; i < length; ++i) { + float in = src[i]; + float relu6 = MSMIN(MSMAX(in + C3NUM, 0), C6NUM); + dst[i] = in * relu6 / C6NUM; + } + return NNACL_OK; +} + +int HSigmoid(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(HSigmoid, i, src, length, dst); + + for (; i < length; ++i) { + float relu6 = MSMIN(MSMAX(src[i] + C3NUM, 0), C6NUM); + dst[i] = relu6 / C6NUM; + } + return NNACL_OK; +} + +int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) { + if (max_val <= min_val) { + return NNACL_ERR; + } + int i = 0; + if (min_val == FLT_MIN) { + SIMD_RUN_NO_SCALAR(HardTanhNoLimitMin, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] > max_val ? max_val : src[i]; + } + } else if (max_val == FLT_MAX) { + SIMD_RUN_NO_SCALAR(HardTanhNoLimitMax, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : src[i]; + } + } else { + SIMD_RUN_NO_SCALAR(HardTanhLimitMinMax, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? 
max_val : src[i]); + } + } + return NNACL_OK; +} + +int Gelu(const float *src, int length, float *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + SIMD_RUN_NO_SCALAR(GeluTanhApproximate, i, src, length, dst); + + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) + for (; i < length; i++) { + dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i])); + } + } else { + SIMD_RUN_NO_SCALAR(GeluErfAPPROXIMATE, i, src, length, dst); + + for (; i < length; i++) { + dst[i] = + 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); // dst = 0.5 * x * (1.0 + x / 1.4142135623730951f)) + } + } + return NNACL_OK; +} + +int Softplus(const float *src, int length, float *dst) { + float log_max = 88.0; + int i = 0; + + SIMD_RUN_NO_SCALAR(Softplus, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(src[i], dst + i); + if (src[i] > log_max) { + dst[i] = src[i]; + } else { + dst[i] = log1p(dst[i]); + } + } + return NNACL_OK; +} + +int Elu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Elu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha); + } + return NNACL_OK; +} + +void Celu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Celu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i] / alpha) * alpha); + } + return; +} + +int HardShrink(const float *src, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(HardShrink, i, src, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = src[i] >= neg_lambd && src[i] <= lambd ? 0 : src[i]; + } + return NNACL_OK; +} + +int SoftShrink(const float *src, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(SoftShrink, i, src, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src[i] > lambd) ? (src[i] - lambd) : ((src[i] < neg_lambd) ? (src[i] + lambd) : (0)); + } + return NNACL_OK; +} + +int SoftsignFp32Opt(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(SoftsignFp32Opt, i, src, length, dst); + for (; i < length; ++i) { + dst[i] = src[i] / (1.0 + fabsf(src[i])); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.h new file mode 100644 index 00000000..28a3da19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ACTIVATION_H_ + +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/activation_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Fp32Relu(const float *src, int length, float *dst); +int Int32Relu(const int32_t *src, int length, int32_t *dst); +int Fp32Relu6(const float *src, int length, float *dst); +int Fp32Clip(const float *src, int length, float *dst, float min, float max); +int Int32Clip(const int32_t *src, int length, int32_t *dst, int min, int max); +int LRelu(const float *src, int length, float *dst, float alpha); +int Sigmoid(const float *src, int length, float *dst); +int Tanh(const float *src, int length, float *dst); +int HSigmoid(const float *src, int length, float *dst); +int Swish(const float *src, int length, float *dst); +int HSwish(const float *src, int length, float *dst); +int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); +int Gelu(const float *src, int length, float *dst, bool approximate); +int Softplus(const float *src, int length, float *dst); +int Elu(const float *src, int length, float *dst, float alpha); +void Celu(const float *src, int length, float *dst, float alpha); +float TanhOpt(float src); +int HardShrink(const float *src, int length, float *dst, float lambd); +int SoftShrink(const float *src, int length, float *dst, float lambd); +int SoftsignFp32Opt(const float *src, int length, float *dst); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32_simd.h.in new file mode 100644 index 00000000..8e9c48bc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/activation_fp32_simd.h.in @@ -0,0 +1,289 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +// clang-format off +#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int Fp32Relu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 zero = SIMD_SET0_F32; + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MAX_F32(SIMD_LD_F32(src + index), zero)); + } + return index; +} + +static inline int Int32Relu@SIMD_INSTRUCTION@(int index, const int32_t *src, int length, int32_t *dst) { + SIMD_EPI32 zero = SIMD_MOV_EPI32(0.0f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(dst + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(src + index), zero)); + } + return index; +} + +static inline int Fp32Relu6@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 zero = SIMD_SET0_F32; + SIMD_F32 six = SIMD_MOV_F32(6.0f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), zero, six)); + } + return index; +} + +static inline int Fp32Clip@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min, float max) { + SIMD_F32 min_val = SIMD_MOV_F32(min); + SIMD_F32 max_val = SIMD_MOV_F32(max); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), min_val, max_val)); + } + return index; +} + +static inline int Int32Clip@SIMD_INSTRUCTION@(int index, const int32_t *src, int length, int32_t *dst, int min, int max) { + SIMD_EPI32 min_val = SIMD_MOV_EPI32(min); + SIMD_EPI32 max_val = SIMD_MOV_EPI32(max); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(dst + index, SIMD_CLAMP_EPI32(SIMD_LD_EPI32(src + index), min_val, max_val)); + } + return index; +} + +static inline int LRelu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + SIMD_F32 alpha_data = SIMD_MOV_F32(alpha); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_MASK mask = SIMD_CMPGT_F32(SIMD_SET0_F32, src_tmp); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_F32(src_tmp, alpha_data), mask)); + } + return index; +} + +static inline int Sigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_SET0_F32, (SIMD_LD_F32(src + index))), dst + index); + SIMD_ST_F32(dst + index, + SIMD_DIV_F32(SIMD_MOV_F32(1.0f), SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index)))); + } + return index; +} + +static inline int Softplus@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 log_max = SIMD_MOV_F32(88.0); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 dst_tmp = SIMD_EXP_F32(src_tmp); + dst_tmp = 
SIMD_LOG_F32(SIMD_ADD_F32(SIMD_MOV_F32(1.0f), dst_tmp)); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(dst_tmp, src_tmp, SIMD_CMPGT_F32(src_tmp, log_max))); + } + return index; +} + +static inline int Tanh@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + index); + SIMD_ST_F32(dst + index, SIMD_TANH_F32(input)); + } + return index; +} + +static inline int Swish@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_SET0_F32, src_value), dst + index); + SIMD_ST_F32(dst + index, + SIMD_DIV_F32(src_value, SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index)))); + } + return index; +} + +static inline int HSwish@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_F32 relu6 = SIMD_CLAMP_N_F32(SIMD_ADD_N_F32(src_value, 3), 0, 6); + SIMD_ST_F32(dst + index, SIMD_DIV_N_F32(SIMD_MUL_F32(src_value, relu6), 6)); + } + return index; +} + +static inline int HSigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_F32 relu6 = SIMD_CLAMP_N_F32(SIMD_ADD_N_F32(src_value, 3), 0, 6); + SIMD_ST_F32(dst + index, SIMD_DIV_N_F32(relu6, 6)); + } + return index; +} + +static inline int HardTanhNoLimitMin@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MIN_N_F32(SIMD_LD_F32(src + index), max_val)); + } + return index; +} + +static inline int HardTanhNoLimitMax@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MAX_N_F32(SIMD_LD_F32(src + index), min_val)); + } + return index; +} + +static inline int HardTanhLimitMinMax@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_CLAMP_N_F32(SIMD_LD_F32(src + index), min_val, max_val)); + } + return index; +} + +static inline int GeluTanhApproximate@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 tmp1 = SIMD_FMADD_F32(SIMD_MUL_N_F32(in, 0.035677408136f), in, SIMD_MOV_F32(0.79788456080287f)); + SIMD_F32 tmp2 = SIMD_MUL_F32(tmp1, in); + SIMD_ST_F32(dst + index, SIMD_MUL_F32(SIMD_MUL_N_F32(in, 0.5f), SIMD_ADD_N_F32(SIMD_TANH_F32(tmp2), 1.0f))); + } + return index; +} + +static inline int Gelu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 para1 = SIMD_MOV_F32(1.4142135623730951f); + SIMD_F32 para2 = SIMD_MOV_F32(1.0f); + 
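// erf-based GELU: dst = 0.5 * x * (1 + erf(x / sqrt(2))); para1 = sqrt(2) +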
SIMD_F32 para3 = SIMD_MOV_F32(0.5f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 res = SIMD_MUL_F32(SIMD_MUL_F32(para3, in), SIMD_ADD_F32(para2, SIMD_ERF_F32(SIMD_DIV_F32(in, para1)))); + SIMD_ST_F32(dst + index, res); + } + return index; +} + +static inline SIMD_F32 SIMD_ERFCCHEB@SIMD_INSTRUCTION@(SIMD_F32 src) { + static const int ncof = 7; + const double cof[7] = {-1.3026537197817094, 6.4196979235649026e-1, 1.9476473204185836e-2, -9.561514786808631e-3, + -9.46595344482036e-4, 3.66839497852761e-4, 4.2523324806907e-5}; + SIMD_F32 dst; + SIMD_F32 d = SIMD_SET0_F32; + SIMD_F32 dd = SIMD_SET0_F32; + SIMD_F32 t = SIMD_DIV_F32(SIMD_MOV_F32(2.0f), SIMD_ADD_F32(src, SIMD_MOV_F32(2.0f))); + SIMD_F32 ty = SIMD_SUB_F32(SIMD_MUL_F32(SIMD_MOV_F32(4.0f), t), SIMD_MOV_F32(2.0f)); + + for (int j = ncof - 1; j > 0; j--) { + SIMD_F32 tmp = d; + d = SIMD_SUB_F32(SIMD_FMADD_F32(ty, d, SIMD_MOV_F32(cof[j])), dd); + dd = tmp; + } + + dst = + SIMD_FMADD_F32(src, src, MS_FSMUL_F32(dd, SIMD_FMADD_F32(ty, d, SIMD_MOV_F32(cof[0])), SIMD_MOV_F32(0.5f))); + dst = SIMD_MUL_F32(t, SIMD_EXP_F32(SIMD_MUL_F32(SIMD_MOV_F32(-1.0f), dst))); + return dst; +} + +static inline SIMD_F32 SIMD_ERF_APPROXIMATE@SIMD_INSTRUCTION@(SIMD_F32 src) { + SIMD_F32 abs_src = SIMD_ABS_F32(src); + SIMD_F32 sign = SIMD_GETSIGN_F32(src); + SIMD_F32 dst = SIMD_ERFCCHEB@SIMD_INSTRUCTION@(abs_src); + return SIMD_MUL_F32(sign, SIMD_SUB_F32(SIMD_MOV_F32(1.0f), dst)); +} + +static inline int GeluErfAPPROXIMATE@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 para1 = SIMD_MOV_F32(1.4142135623730951f); + SIMD_F32 para2 = SIMD_MOV_F32(1.0f); + SIMD_F32 para3 = SIMD_MOV_F32(0.5f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 res = SIMD_MUL_F32(SIMD_MUL_F32(para3, in), SIMD_ADD_F32(para2, SIMD_ERF_APPROXIMATE@SIMD_INSTRUCTION@(SIMD_DIV_F32(in, para1)))); + SIMD_ST_F32(dst + index, res); + } + return index; +} + +static inline int Elu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(src_tmp), 1.0f); + SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); + } + return index; +} + +static inline int Celu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(SIMD_DIV_N_F32(src_tmp, alpha)), 1.0f); + SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); + } + return index; +} + +static inline int HardShrink@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_t = SIMD_LD_F32(src + index); + /* v0 = (in > lamdb) & in */ + SIMD_F32 value0 = 
SIMD_AND_MASK_F32(SIMD_CMPGT_F32(src_t, pos_lamdb_v), src_t); + /* v1 = (in < -lamdb) & in */ + SIMD_F32 value1 = SIMD_AND_MASK_F32(SIMD_CMPLT_F32(src_t, neg_lamdb_v), src_t); + /* out = (v0 | v1) */ + SIMD_ST_F32(dst + index, SIMD_OR_F32(value0, value1)); + } + return index; +} + +static inline int SoftShrink@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_t = SIMD_LD_F32(src + index); + /* v0 = (in > lamdb) & (in - lamdb) */ + SIMD_F32 value0 = SIMD_AND_MASK_F32(SIMD_CMPGT_F32(src_t, pos_lamdb_v), SIMD_SUB_F32(src_t, pos_lamdb_v)); + /* v1 = (in < -lamdb) & (in + lamdb) */ + SIMD_F32 value1 = SIMD_AND_MASK_F32(SIMD_CMPLT_F32(src_t, neg_lamdb_v), SIMD_ADD_F32(src_t, pos_lamdb_v)); + /* out = (v0 | v1) */ + SIMD_ST_F32(dst + index, SIMD_OR_F32(value0, value1)); + } + return index; +} + +static inline int SoftsignFp32Opt@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 divisor_tmp = SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_ABS_F32(src_tmp)); + SIMD_ST_F32(dst + index, SIMD_DIV_F32(src_tmp, divisor_tmp)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.c new file mode 100644 index 00000000..499957c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.c @@ -0,0 +1,239 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <math.h> +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/fp32/adam_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX512 +#include "nnacl_c/avx512/adam_fp32_avx512.h" +#endif + +int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient, + size_t start, size_t end, bool use_nesterov) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + __m256 coeff1_r = _mm256_set1_ps(1 - beta1); + __m256 coeff2_r = _mm256_set1_ps(1 - beta2); + __m256 beta1_r = _mm256_set1_ps(beta1); + __m256 lr_r = _mm256_set1_ps(lr); + __m256 epsi_r = _mm256_set1_ps(epsilon); + + float *var_ptr = var + start; + float *m_ptr = m + start; + float *v_ptr = v + start; + const float *grad_ptr = gradient + start; + + __m256 avx_r0, avx_r1; + __m256 var_r, m_r, v_r, grad_r; + + for (; c1 < start + c8; c1 += C8NUM) { + grad_r = _mm256_loadu_ps(grad_ptr); + m_r = _mm256_loadu_ps(m_ptr); + avx_r0 = _mm256_sub_ps(grad_r, m_r); + avx_r1 = _mm256_mul_ps(avx_r0, coeff1_r); + m_r = _mm256_add_ps(m_r, avx_r1); + _mm256_storeu_ps(m_ptr, m_r); + + v_r = _mm256_loadu_ps(v_ptr); + avx_r0 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), v_r); + v_r = _mm256_add_ps(v_r, _mm256_mul_ps(avx_r0, coeff2_r)); + _mm256_storeu_ps(v_ptr, v_r); + + if (use_nesterov) { + avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r)); + avx_r1 = _mm256_mul_ps(lr_r, avx_r0); + avx_r0 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + __m256 avx_r2 = _mm256_div_ps(avx_r1, avx_r0); + + var_r = _mm256_loadu_ps(var_ptr); + var_r = _mm256_sub_ps(var_r, avx_r2); + _mm256_storeu_ps(var_ptr, var_r); + } else { + avx_r0 = _mm256_mul_ps(lr_r, m_r); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + __m256 avx_r2 = _mm256_div_ps(avx_r0, avx_r1); + var_r = _mm256_loadu_ps(var_ptr); + var_r = _mm256_sub_ps(var_r, avx_r2); + _mm256_storeu_ps(var_ptr, var_r); + } + m_ptr += C8NUM; + v_ptr += C8NUM; + var_ptr += C8NUM; + grad_ptr += C8NUM; + } +#endif + + // remaining + for (; c1 < end; c1++) { + m[c1] += (gradient[c1] - m[c1]) * (1 - beta1); + v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2); + if (use_nesterov) { + var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon); + } else { + var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon); + } + } + return NNACL_OK; +} + +int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon, + const float *gradient, size_t start, size_t end, bool use_nesterov) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + __m256 coeff1_r = _mm256_set1_ps(1.0f - beta1); + __m256 coeff2_r = _mm256_set1_ps(1.0f - beta2); + __m256 beta1_r = _mm256_set1_ps(beta1); + __m256 beta2_r = _mm256_set1_ps(beta2); + __m256 lr_r = _mm256_set1_ps(-lr); + __m256 epsi_r = _mm256_set1_ps(epsilon); + + float *m_ptr = m + start; + float *v_ptr = v + start; + float *delta_ptr = delta + start; + const float *gradient_ptr = gradient + start; + + __m256 m_r, v_r, delta_r, grad_r; + __m256 avx_r0, avx_r1; + for (; c1 < start + c8; c1 += C8NUM) { + m_r = _mm256_loadu_ps(m_ptr); + avx_r0 = _mm256_mul_ps(m_r, beta1_r); + grad_r = _mm256_loadu_ps(gradient_ptr); + m_r = _mm256_add_ps(avx_r0, _mm256_mul_ps(coeff1_r, grad_r)); + _mm256_storeu_ps(m_ptr, m_r); + + v_r = _mm256_loadu_ps(v_ptr); + avx_r0 = _mm256_mul_ps(v_r, beta2_r); + avx_r1 = _mm256_mul_ps(_mm256_mul_ps(coeff2_r, grad_r), grad_r); +
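/* second-moment update: v = beta2 * v + (1 - beta2) * g * g */ +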
v_r = _mm256_add_ps(avx_r0, avx_r1); + _mm256_storeu_ps(v_ptr, v_r); + + if (use_nesterov) { + avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r)); + avx_r0 = _mm256_mul_ps(lr_r, avx_r0); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + delta_r = _mm256_div_ps(avx_r0, avx_r1); + _mm256_storeu_ps(delta_ptr, delta_r); + } else { + avx_r0 = _mm256_mul_ps(lr_r, m_r); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + delta_r = _mm256_div_ps(avx_r0, avx_r1); + _mm256_storeu_ps(delta_ptr, delta_r); + } + m_ptr += C8NUM; + v_ptr += C8NUM; + delta_ptr += C8NUM; + gradient_ptr += C8NUM; + } +#endif + + // remaining + for (; c1 < end; ++c1) { + m[c1] *= beta1; + m[c1] += (1 - beta1) * gradient[c1]; + v[c1] *= beta2; + v[c1] += (1 - beta2) * gradient[c1] * gradient[c1]; + if (use_nesterov) { + delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon); + } else { + delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon); + } + } + return NNACL_OK; +} + +int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t start, size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end); + + // remaining + const float beta1_minus = 1 - beta1; + const float beta2_minus = 1 - beta2; + for (; c1 < end; c1++) { + m[c1] += (gradient[c1] - m[c1]) * beta1_minus; + v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * beta2_minus; + var[c1] -= lr * (m[c1] / (sqrt(v[c1]) + epsilon) + decay * var[c1]); + } + return NNACL_OK; +} + +size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + + SIMD_RUN_AVX512(FusedCastAdamFp32Fp16, c1, var, gradient16, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + + SIMD_RUN_AVX512(FusedCastAdamFp32Fp32, c1, var, gradient32, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(FusedCastAdamFp16Fp16, c1, var16, gradient16, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(FusedCastAdamFp16Fp32, c1, var16, gradient32, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power, + float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) { + if ((1.f - beta1_power[0]) <= 0.0f) { + return NNACL_PARAM_INVALID; + } + if ((1.f - beta2_power[0]) < 0.0f) { + return 
NNACL_ERRCODE_SQRT_NEGATIVE; + } + + float update_lr = learning_rate * sqrtf(1.f - beta2_power[0]) / (1.f - beta1_power[0]); + const float one_minus_beta1 = 1.f - beta1; + const float one_minus_beta2 = 1.f - beta2; + if (nesterov) { // Nadam + for (int i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; + weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (sqrtf(v[i]) + eps); + } + } else { + for (int i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; + weight[i] -= update_lr * m[i] / (sqrtf(v[i]) + eps); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.h new file mode 100644 index 00000000..4f2b0e98 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADAM_FP32_H +#define MINDSPORE_NNACL_ADAM_FP32_H +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient, + size_t start, size_t end, bool use_nesterov); +int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon, + const float *gradient, size_t start, size_t end, bool use_nesterov); +int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t start, size_t end); +size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); + +int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power, + float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADAM_FP32_H diff --git
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32_simd.h.in new file mode 100644 index 00000000..806f71bd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adam_fp32_simd.h.in @@ -0,0 +1,203 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ADAM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ADAM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ +#ifdef MS_SIMD_AVX512 + static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_LD_F32(gradient + index); + + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + SIMD_ST_F32(var + index, var_r); + } + + return index; +} + +static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 
+ index)); + + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(var + index, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + } + + return index; +} + +static inline size_t FusedCastAdamFp32Fp32@SIMD_INSTRUCTION@(size_t index, float *var, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index); + + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(var + index, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + } + + return index; +} + +static inline size_t FusedCastAdamFp16Fp16@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index)); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 + index)); + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); +
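/* moments stay in fp32; the updated variable is converted back to fp16 on store */ +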
SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0)); + } + + return index; +} + +static inline size_t FusedCastAdamFp16Fp32@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16 + index)); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index); + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0)); + } + + return index; +} +#endif + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.c new file mode 100644 index 00000000..f3e16f5c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.c @@ -0,0 +1,156 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/add_fp32_simd.h" + +int ElementOptAdd(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAdd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAdd, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptAddExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddExtNum0, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index] * alpha; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddExtNum1, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0] * alpha; + } + } + return NNACL_OK; +} + +int ElementOptAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddInt, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] + in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] + in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] + in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu6, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] + in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementAdd(tile_in0, tile_in1, out, size); +} + +int ElementAdd(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAdd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index]; + } + return NNACL_OK; +} + +int ElementAddExt(const float *in0, const float *in1, const float alpha, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddExt, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index] * alpha; + } + return NNACL_OK; +} + +int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddRelu, index, in0, in1, out, size); + for (; 
index < size; index++) { + float res = in0[index] + in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] + in1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.h new file mode 100644 index 00000000..783bb8a8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32.h @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ADD_H_ +#define MINDSPORE_NNACL_FP32_ADD_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ElementAdd(const float *in0, const float *in1, float *out, int size); +int ElementAddExt(const float *in0, const float *in1, const float alpha, float *out, int size); +int ElementOptAddExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar); +int ElementAddRelu(const float *in0, const float *in1, float *out, int size); +int ElementAddRelu6(const float *in0, const float *in1, float *out, int size); +int ElementAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptAdd(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ADD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32_simd.h.in new file mode 100644 index 00000000..ea8f846f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/add_fp32_simd.h.in @@ -0,0 +1,153 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ADD_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ADD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptAdd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0_, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddExtNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + SIMD_F32 vin0 = SIMD_MOV_F32(in0[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddExtNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + SIMD_F32 vin1 = SIMD_MOV_F32(in1[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, + int size) { + SIMD_EPI32 vin0_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_ADD_EPI32(vin0_, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_ADD_F32(vin0_, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_ADD_F32(vin0_, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static 
inline int ElementAdd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddExt@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_ADD_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_ADD_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_ADD_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.c new file mode 100644 index 00000000..b5ef4123 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/adder_fp32.h" +#include +#include +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" + +void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, + int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c4div = c / 4, c4mod = c % 4; + size_t ci = r * stride + c; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c4div * deep * 4 + d * 4 + c4mod; + value += fabsf(a[ai] - b[bi]); + } + value = -value; + if (bias != NULL) value += bias[c]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type != ActType_No) value = MSMAX(0.0f, value); + dst[ci] = value; + } + } +} + +void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, + size_t stride) { +#ifdef ENABLE_ARM64 + AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride); +#else + Adder12x4(a, b, c, bias, act_type, deep, row, col, stride); +#endif +} + +void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) { + int out_channel = conv_param->output_channel_; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + int output_count = conv_param->output_h_ * conv_param->output_w_; + if (conv_param->thread_num_ == 0) { + return; + } +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + const int cal_num = C4NUM; +#else + const int cal_num = C12NUM; +#endif + int output_tile_count = UP_DIV(output_count, cal_num); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * output_count; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * cal_num; + int real_cal_num = (output_count - start_index) < cal_num ? 
(output_count - start_index) : cal_num; + float *gemm_input = packed_input + task_id * deep * cal_num; + float *col_major_gemm_input = col_major_input + task_id * deep * cal_num; + size_t packed_input_size = deep * cal_num * sizeof(float); + memset(gemm_input, 0, packed_input_size); + memset(col_major_gemm_input, 0, packed_input_size); + Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + + int out_offset = thread_id * cal_num * out_channel + out_batch_offset; + float *gemm_output = output_data + out_offset; +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); +#else + RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); +#endif + AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num, + out_channel, out_channel); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.h new file mode 100644 index 00000000..ee59ec66 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/adder_fp32.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ADDER_H_ +#define MINDSPORE_NNACL_FP32_ADDER_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef ENABLE_ARM64 +void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, size_t stride); +#endif + +void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, + size_t stride); + +void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ADDER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.c new file mode 100644 index 00000000..cb8f49d7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.c @@ -0,0 +1,298 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/arg_min_max_fp32.h" +#include <float.h> + +#define ARG_MIN_MAX_FUNC(data_type) \ + int ArgCompareDesc32##data_type(const void *a, const void *b) { \ + DATA_TYPE b_value = ((ArgElement *)b)->data_.UNION_DATA; \ + DATA_TYPE a_value = ((ArgElement *)a)->data_.UNION_DATA; \ + if (b_value > a_value) { \ + return 1; \ + } \ + if (b_value < a_value) { \ + return -1; \ + } \ + return 0; \ + } \ + int ArgCompareAsc32##data_type(const void *a, const void *b) { \ + DATA_TYPE a_value = ((ArgElement *)a)->data_.UNION_DATA; \ + DATA_TYPE b_value = ((ArgElement *)b)->data_.UNION_DATA; \ + if (b_value > a_value) { \ + return -1; \ + } \ + if (b_value < a_value) { \ + return 1; \ + } \ + return 0; \ + } \ + \ + void ArgMaxTopK1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const ArgMinMaxComputeParam *param, int pre_axis_count, int axis_count, \ + int after_axis_count) { \ + bool out_value = param->out_value_; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < pre_axis_count; ++i) { \ + int output_offset = i * after_axis_count; \ + int input_offset = output_offset * axis_count; \ + for (int j = 0; j < after_axis_count; ++j) { \ + DATA_TYPE value = MIN_VALUE; \ + int index = 0; \ + for (int k = 0; k < axis_count; ++k) { \ + DATA_TYPE value_tmp = input[input_offset + k * after_axis_count + j]; \ + if (value_tmp > value) { \ + value = value_tmp; \ + index = k; \ + } \ + } \ + if (out_value) { \ + outputfp32[output_offset + j] = value; \ + } else { \ + outputint[output_offset + j] = index; \ + } \ + if (output_value != NULL) { \ + output_value[output_offset + j] = value; \ + } \ + } \ + } \ + } \ + \ + void ArgMinTopK1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const ArgMinMaxComputeParam *param, int pre_axis_count, int axis_count, \ + int after_axis_count) { \ + bool out_value = param->out_value_; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < pre_axis_count; ++i) { \ + int output_offset = i * after_axis_count; \ + int input_offset = output_offset * axis_count; \ + for (int j = 0; j < after_axis_count; ++j) { \ + DATA_TYPE value = MAX_VALUE; \ + int index = 0; \ + for (int k = 0; k < axis_count; ++k) { \ + DATA_TYPE value_tmp = input[input_offset + k * after_axis_count + j]; \ + if (value_tmp < value) { \ + value = value_tmp; \ + index = k; \ + } \ + } \ + if (out_value) { \ + outputfp32[output_offset + j] = value; \ + } else { \ + outputint[output_offset + j] = index; \ + } \ + if (output_value != NULL) { \ + output_value[output_offset + j] = value; \ + } \ + } \ + } \ + } \ + \ + void ArgMinMaxDim0##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { \ + for (int j = 0; j < in_shape[0]; ++j) { \ + int offset = param->in_strides_[0] * j + i; \ + param->arg_elements_[j].index_ = (uint32_t)j; \ + param->arg_elements_[j].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func); \ + for (int j = 0; j < param->topk_; ++j) { \ + int out_offset = j *
param->out_strides_[0] + i; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[j].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[j].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[j].data_.UNION_DATA; \ + } \ + } \ + } \ + return; \ + } \ + \ + void ArgMinMaxDim1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + int in_shape1 = in_shape[1]; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < param->in_strides_[1]; ++j) { \ + for (int k = 0; k < in_shape1; ++k) { \ + int offset = param->in_strides_[1] * k + in_dim0_offset + j; \ + param->arg_elements_[k].index_ = (uint32_t)k; \ + param->arg_elements_[k].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func); \ + for (int k = 0; k < param->topk_; ++k) { \ + int out_offset = out_dim0_offset + j + k * param->out_strides_[1]; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[k].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[k].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[k].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + return; \ + } \ + \ + void ArgMinMaxDim2##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + int in_shape1 = in_shape[1]; \ + int in_shape2 = in_shape[2]; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < in_shape1; ++j) { \ + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; \ + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; \ + for (int k = 0; k < param->in_strides_[2]; ++k) { \ + for (int l = 0; l < in_shape2; ++l) { \ + int offset = param->in_strides_[2] * l + k + in_dim1_offset; \ + param->arg_elements_[l].index_ = (uint32_t)l; \ + param->arg_elements_[l].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func); \ + for (int l = 0; l < param->topk_; ++l) { \ + int out_offset = out_dim1_offset + k + l * param->out_strides_[2]; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[l].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + } \ + } \ + \ + void ArgMinMaxDim3##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + int in_shape1 = in_shape[1]; \ + int in_shape2 = in_shape[2]; \ + int in_shape3 = in_shape[3]; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < in_shape[0]; ++i) { 
\ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < in_shape1; ++j) { \ + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; \ + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; \ + for (int k = 0; k < in_shape2; ++k) { \ + int in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; \ + int out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; \ + for (int l = 0; l < in_shape3; ++l) { \ + int offset = l + in_dim2_offset; \ + param->arg_elements_[l].index_ = (uint32_t)l; \ + param->arg_elements_[l].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func); \ + for (int l = 0; l < param->topk_; ++l) { \ + int out_offset = out_dim2_offset + l; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = (int)(param->arg_elements_[l].index_); \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + } \ + } \ + \ + void ArgMinMax##data_type##32(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param) { \ + if (param->topk_ == 1) { \ + int pre_axis_count = 1; \ + int axis_count = 1; \ + int after_axis_count = 1; \ + ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); \ + \ + if (param->get_max_) { \ + ArgMaxTopK1##data_type(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); \ + } else { \ + ArgMinTopK1##data_type(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); \ + } \ + return; \ + } \ + \ + COMPARE_FUNCTION compare_function = NULL; \ + if (param->get_max_) { \ + compare_function = ArgCompareDesc32##data_type; \ + } else { \ + compare_function = ArgCompareAsc32##data_type; \ + } \ + \ + switch (param->axis_) { \ + case 0: \ + ArgMinMaxDim0##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 1: \ + ArgMinMaxDim1##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 2: \ + ArgMinMaxDim2##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 3: \ + ArgMinMaxDim3##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + } \ + return; \ + } + +#define DATA_TYPE float +#define MIN_VALUE -FLT_MAX +#define MAX_VALUE FLT_MAX +#define UNION_DATA f_data_ +ARG_MIN_MAX_FUNC(Fp) +#undef DATA_TYPE +#undef MIN_VALUE +#undef MAX_VALUE +#undef UNION_DATA + +#define DATA_TYPE int32_t +#define MIN_VALUE INT32_MIN +#define MAX_VALUE INT32_MAX +#define UNION_DATA i_data_ +ARG_MIN_MAX_FUNC(Int) +#undef DATA_TYPE +#undef MIN_VALUE +#undef MAX_VALUE +#undef UNION_DATA diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.h new file mode 100644 index 00000000..8895eeab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arg_min_max_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FP32_ARG_MIN_MAX_FP32_H_ +#define FP32_ARG_MIN_MAX_FP32_H_ + +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ArgMinMaxFp32(const float *input, void *output, float *output_value, const int32_t *in_shape, + const ArgMinMaxComputeParam *param); +void ArgMinMaxInt32(const int32_t *input, void *output, int32_t *output_value, const int32_t *in_shape, + const ArgMinMaxComputeParam *param); +#ifdef __cplusplus +} +#endif + +#endif // FP32_ARG_MIN_MAX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.c new file mode 100644 index 00000000..70d0a55a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/arithmetic_compare_fp32.h" + +inline bool EqualFp32(float x, float y); +inline bool EqualBool(bool x, bool y); +inline bool NotEqualFp32(float x, float y); +inline bool LessFp32(float x, float y); +inline bool LessEqualFp32(float x, float y); +inline bool GreaterFp32(float x, float y); +inline bool GreaterEqualFp32(float x, float y); + +inline bool EqualInt32(int x, int y); +inline bool NotEqualInt32(int x, int y); +inline bool NotEqualInt64(int64_t x, int64_t y); +inline bool LessInt32(int x, int y); +inline bool LessEqualInt32(int x, int y); +inline bool GreaterInt32(int x, int y); +inline bool GreaterEqualInt32(int x, int y); + +bool EqualFp32(float x, float y) { return x == y; } +bool EqualBool(bool x, bool y) { return x == y; } +bool NotEqualFp32(float x, float y) { return x != y; } +bool LessFp32(float x, float y) { return x < y; } +bool LessEqualFp32(float x, float y) { return x <= y; } +bool GreaterFp32(float x, float y) { return x > y; } +bool GreaterEqualFp32(float x, float y) { return x >= y; } + +bool EqualInt32(int x, int y) { return x == y; } +bool NotEqualInt32(int x, int y) { return x != y; } +bool NotEqualInt64(int64_t x, int64_t y) { return x != y; } +bool LessInt32(int x, int y) { return x < y; } +bool LessEqualInt32(int x, int y) { return x <= y; } +bool GreaterInt32(int x, int y) { return x > y; } +bool GreaterEqualInt32(int x, int y) { return x >= y; } + +#define ELEMENT_COMPARE(input0, input1, output, element_size, compare_func) \ + do { \ + for (int i = 0; i < element_size; i++) { \ + output[i] = compare_func(input0[i], input1[i]); \ + } \ + return NNACL_OK; \ + } while (0) + +#define ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, compare_func) \ + do { \ + int i = 0; \ + if (first_scalar) { \ + for (; i < element_size; i++) { \ + output[i] = compare_func(input0[0], input1[i]); \ + } \ + } else { \ + for (; i < element_size; i++) { \ + output[i] = compare_func(input0[i], input1[0]); \ + } \ + } \ + return NNACL_OK; \ + } while (0) + +// equal: +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualFp32); +} + +int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualBool); +} + +int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualFp32); +} + +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualInt32); +} + +int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualInt32); +} + +// not equal +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualFp32); +} + +int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualFp32); +} + +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + 
ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt32); +} + +int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt32); +} + +int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt64); +} + +int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt64); +} + +// less +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessFp32); +} + +int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessFp32); +} + +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessInt32); +} + +int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessInt32); +} + +// less equal +int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualFp32); +} + +int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualFp32); +} + +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualInt32); +} + +int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualInt32); +} + +// greater +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterFp32); +} + +int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterFp32); +} + +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterInt32); +} + +int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterInt32); +} + +// greater equal +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualFp32); +} + +int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, 
element_size, first_scalar, GreaterEqualFp32); +} + +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualInt32); +} + +int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualInt32); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.h new file mode 100644 index 00000000..5ffc7a48 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_compare_fp32.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ +#define MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size); +int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size); +int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int
ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c new file mode 100644 index 00000000..7cba904f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.c @@ -0,0 +1,482 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include +#include "nnacl_c/arithmetic_fp32_simd.h" + +#define ACCURACY_DATA 0.00000001 + +int ElementFloorMod(const float *in0, const float *in1, float *out, int size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementFloorMod, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = in0[i] - floorf(in0[i] / in1[i]) * in1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int i = 0; + + if (first_scalar) { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorModNum0, i, in0, in1, out, size); // neon no floor instruction + for (; i < size; i++) { + out[i] = in0[0] - floorf(in0[0] / in1[i]) * in1[i]; + } + } else { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorModNum1, i, in0, in1, out, size); // neon no floor instruction + for (; i < size; i++) { + out[i] = in0[i] - floorf(in0[i] / in1[0]) * in1[0]; + } + } + + return NNACL_OK; +} + +int ElementFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int i = 0; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + int remainder = in0[i] - (in0[i] / in1[i]) * in1[i]; + out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder; + } + return NNACL_OK; +} + +int ElementOptFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int i = 0; + if (first_scalar) { + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + int remainder = in0[0] - (in0[0] / in1[i]) * in1[i]; + out[i] = (remainder != 0) && ((in0[0] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + for (; i < size; i++) { + int remainder = in0[i] - (in0[i] / in1[0]) * in1[0]; + out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[0] > 0)) ? 
remainder + in1[0] : remainder; + } + } + + return NNACL_OK; +} + +int ElementMod(const float *in0, const float *in1, float *out, int size) { + for (int i = 0; i < size; i++) { + out[i] = fmodf(in0[i], in1[i]); + } + return NNACL_OK; +} + +int ElementOptMod(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = fmodf(in0[0], in1[index]); + } + } else { + for (; index < size; index++) { + out[index] = fmodf(in0[index], in1[0]); + } + } + return NNACL_OK; +} + +int ElementModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int i = 0; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[i] % in1[i]; + } + return NNACL_OK; +} + +int ElementOptModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + if (first_scalar) { + for (int index = 0; index < size; index++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[index]); + out[index] = in0[0] % in1[index]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + for (int index = 0; index < size; index++) { + out[index] = in0[index] % in1[0]; + } + } + return NNACL_OK; +} + +int ElementFloorDiv(const float *in0, const float *in1, float *out, int size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementFloorDiv, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[i] / in1[i]); + } + return NNACL_OK; +} + +int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int i = 0; + + if (first_scalar) { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorDivNum0, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[0] / in1[i]); + } + } else { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorDivNum1, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[i] / in1[0]); + } + } + + return NNACL_OK; +} + +int ElementFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementFloorDivInt, i, in0, in1, out, size); + + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int i = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptFloorDivIntNum0, i, in0, in1, out, size); + + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[0] / in1[i]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + + SIMD_RUN_NO_SCALAR(ElementOptFloorDivIntNum1, i, in0, in1, out, size); + + for (; i < size; i++) { + out[i] = in0[i] / in1[0]; + } + } + + return NNACL_OK; +} + +int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementLogicalAnd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) & (bool)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + SIMD_RUN_NO_SCALAR(ElementOptLogicalAnd, index, in0, in1, out, size, first_scalar); + if (first_scalar) { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[0]) & (bool)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = 
(float)((bool)(in0[index]) & (bool)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[0]) & (unsigned int)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index])); + } + + return NNACL_OK; +} + +int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[0]) & (unsigned int)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) { + int index = 0; +#ifdef ENABLE_NEON + float32x4_t vtrue = vdupq_n_f32(1); + float32x4_t vfalse = vdupq_n_f32(0); + uint32x4_t mask = vmovq_n_u32(((uint32_t)(1u << 31) - 1)); + uint32x4_t zeros = vdupq_n_u32(0); + for (; index <= size - 4; index += C4NUM) { + uint32x4_t vin0 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in0 + index)), mask); + uint32x4_t vin1 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in1 + index)), mask); + float32x4_t vout = vbslq_f32(vceqq_u32(vorrq_u32(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f32(out + index, vout); + } +#endif + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) | (bool)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[0]) | (bool)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) | (bool)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (bool)(in0[index] | in1[index]); + } + return NNACL_OK; +} + +int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (bool)(in0[0] | in1[index]); + } + } else { + for (; index < size; index++) { + out[index] = (bool)(in0[index] | in1[0]); + } + } + + return NNACL_OK; +} + +int ElementMaximum(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMaximum, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? 
in0[index] : in1[index]; + } + return NNACL_OK; +} + +int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMaximumNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in0[0] : in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMaximumNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? in0[index] : in1[0]; + } + } + + return NNACL_OK; +} + +int ElementMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMaximumInt, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? in0[index] : in1[index]; + } + return NNACL_OK; +} + +int ElementOptMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMaximumIntNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in0[0] : in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMaximumIntNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? in0[index] : in1[0]; + } + } + + return NNACL_OK; +} + +int ElementMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMinimumInt, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[index] > input1[index] ? input1[index] : input0[index]; + } + return NNACL_OK; +} + +int ElementOptMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMinimumIntNum0, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[0] > input1[index] ? input1[index] : input0[0]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMinimumIntNum1, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[index] > input1[0] ? input1[0] : input0[index]; + } + } + + return NNACL_OK; +} + +int ElementMinimum(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMinimum, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? in1[index] : in0[index]; + } + return NNACL_OK; +} + +int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMinimumNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in1[index] : in0[0]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMinimumNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? 
in1[0] : in0[index]; + } + } + + return NNACL_OK; +} + +#undef ACCURACY_DATA + +void TileOneDimensionFp32(const void *inPtr, void *outPtr, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple) { + const float *inData = (const float *)inPtr; + float *outData = (float *)outPtr; + + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionFp32(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data0, float *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionFp32(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionFp32(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +void AssignSubOpt(float *in0, const float *in1, size_t size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(AssignSubOpt, index, in0, in1, size); + + for (; index < size; index++) { + in0[index] = in0[index] - in1[index]; + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h new file mode 100644 index 00000000..005c90a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_ARITHMETIC_H_ +#define MINDSPORE_NNACL_ARITHMETIC_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/fp32/mul_fp32.h" +#include "nnacl_c/fp32/div_fp32.h" +#include "nnacl_c/fp32/sub_fp32.h" +#include "nnacl_c/fp32/squared_difference.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TileOneDimensionFp32(const void *inData, void *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple); +void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data0, float *tile_data1, + ArithmeticParameter *param); +/* logical and */ +int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size); +int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size); +int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar); + +/* logical or */ +int ElementLogicalOr(const float *in0, const float *in1, float *out, int size); +int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size); +int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar); + +/* max min */ +int ElementMaximum(const float *in0, const float *in1, float *out, int size); +int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementMinimum(const float *in0, const float *in1, float *out, int size); +int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size); +int ElementOptMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size, bool first_scalar); + +/* floor div */ +int ElementFloorDiv(const float *in0, const float *in1, float *out, int size); +int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +/* floor mod */ +int ElementFloorMod(const float *in0, const float *in1, float *out, int size); +int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +/* mod */ +int ElementMod(const float *in0, const float *in1, float *out, int size); +int ElementOptMod(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int
ElementModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +void AssignSubOpt(float *in0, const float *in1, size_t size); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32_simd.h.in new file mode 100644 index 00000000..24688952 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_fp32_simd.h.in @@ -0,0 +1,287 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_ARITHMETIC_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_ARITHMETIC_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +#ifndef MS_SIMD_NEON +static inline int ElementFloorMod@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorModNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorModNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementFloorDiv@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 
floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, floor_tmp); + } + return index; +} + +static inline int ElementOptFloorDivNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} +#endif + +static inline int ElementFloorDivInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMaximum@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 
+ index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMaximumInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMinimumInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMinimum@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 
in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline size_t AssignSubOpt@SIMD_INSTRUCTION@(int index, float *in0, const float *in1, size_t size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(in0 + index, out_tmp); + } + return index; +} + +int ElementLogicalAnd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +int ElementOptLogicalAnd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size, bool first_scalar) { + if (first_scalar) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(*in0); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + } else { + SIMD_F32 in1_tmp = SIMD_MOV_F32(*in1); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.c new file mode 100644 index 00000000..a85950c5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.c @@ -0,0 +1,230 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <math.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include "nnacl_c/fp32/arithmetic_self_fp32.h"
+#include "nnacl_c/arithmetic_self_fp32_simd.h"
+
+int ElementAbs(const float *input, float *output, const int element_size) {
+  int i = 0;
+
+  // only AVX512 supports the abs fp32 instruction
+  SIMD_RUN_AVX512(ElementAbs, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    output[i] = fabsf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementAbsInt(const int32_t *input, int32_t *output, const int element_size) {
+  int i = 0;
+
+  // only AVX512 supports the abs int32 instruction
+  SIMD_RUN_AVX512(ElementAbsInt, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    output[i] = abs(input[i]);
+  }
+  return NNACL_OK;
+}
+
+// cos
+int ElementCos(const float *input, float *output, const int element_size) {
+  int i = 0;
+  SIMD_RUN_X86_NO_SCALAR(ElementCos, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    output[i] = cosf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+// log:
+int ElementLog(const float *input, float *output, const int element_size) {
+  int i = 0;
+
+  SIMD_RUN_X86_NO_SCALAR(ElementLog, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    if (MS_UNLIKELY(input[i] < 0)) {
+      return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO;
+    }
+    output[i] = logf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+// log1p:
+int ElementLog1p(const float *input, float *output, const int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    if (MS_UNLIKELY(input[i] < -1.0f)) {
+      return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO;
+    }
+    output[i] = log1p(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementSquare(const float *input, float *output, const int element_size) {
+  int i = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementSquare, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    output[i] = input[i] * input[i];
+  }
+  return NNACL_OK;
+}
+
+int ElementSqrt(const float *input, float *output, const int element_size) {
+  int i = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementSqrt, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    if (MS_UNLIKELY(input[i] < 0)) {
+      return NNACL_ERRCODE_SQRT_NEGATIVE;
+    }
+    output[i] = sqrtf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementRsqrt(const float *input, float *output, const int element_size) {
+  int i = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementRsqrt, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    if (MS_UNLIKELY(input[i] < 0)) {
+      return NNACL_ERRCODE_RSQRT_NEGATIVE;
+    }
+    output[i] = 1.f / sqrtf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementSin(const float *input, float *output, const int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = sinf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+// logical_not:
+int ElementLogicalNot(const float *input, float *output, const int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = (float)(!((bool)(input[i])));
+  }
+  return NNACL_OK;
+}
+
+// logical_not:
+int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) {
+  for (int i = 0; i < element_size; i++) {
+    output[i] = !input[i];
+  }
+  return NNACL_OK;
+}
+
+int ElementRound(const float *input, float *output, const int element_size) {
+  int i = 0;
+
+  SIMD_RUN_AVX(ElementRound, i, input, output, element_size);
+  SIMD_RUN_SSE(ElementRound, i, input, output, element_size);
+  for (; i < element_size; i++) {
+    output[i] = roundf(input[i]);
+  }
+  return NNACL_OK;
+}
+
+int ElementFloor(const float *input, float *output, const int element_size) {
+  int i = 0;
+
+
SIMD_RUN_X86_NO_SCALAR(ElementFloor, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return NNACL_OK; +} + +int ElementCeil(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementCeil, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = ceilf(input[i]); + } + return NNACL_OK; +} + +int ElementNegative(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementNegative, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementNegativeInt(const int32_t *input, int32_t *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementNegativeInt, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementReciprocal(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementReciprocal, i, input, output, element_size); + for (; i < element_size; ++i) { + if (input[i] == 0.0f) { + return NNACL_ERR; + } + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} + +// Erf +int ElementErf(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = erff(input[i]); + } + return NNACL_OK; +} + +int ElementIsFinite(const float *input, bool *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = true; + if (isnan(input[i]) || isinf(input[i])) { + output[i] = false; + } + } + return NNACL_OK; +} + +int ElementMish(const float *input, float *output, const int element_size) { + int i = 0; + SIMD_RUN_NO_SCALAR(ElementMish, i, input, output, element_size); + + for (; i < element_size; ++i) { + simd_exp32(input[i], output + i); + float exp_pow = (output[i] + 1) * (output[i] + 1); + output[i] = input[i] * (exp_pow - 1) / (exp_pow + 1); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.h new file mode 100644 index 00000000..5c4213c5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_NNACL_ARITHMETIC_SELF_H_
+#define MINDSPORE_NNACL_ARITHMETIC_SELF_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ElementAbs(const float *input, float *output, const int element_size);
+int ElementAbsInt(const int32_t *input, int32_t *output, const int element_size);
+
+int ElementCos(const float *input, float *output, const int element_size);
+
+int ElementLog(const float *input, float *output, const int element_size);
+
+int ElementLog1p(const float *input, float *output, const int element_size);
+
+int ElementSquare(const float *input, float *output, const int element_size);
+
+int ElementSqrt(const float *input, float *output, const int element_size);
+
+int ElementRsqrt(const float *input, float *output, const int element_size);
+
+int ElementSin(const float *input, float *output, const int element_size);
+
+int ElementLogicalNot(const float *input, float *output, const int element_size);
+
+int ElementLogicalNotBool(const bool *input, bool *output, const int element_size);
+
+int ElementRound(const float *input, float *output, const int element_size);
+
+int ElementFloor(const float *input, float *output, const int element_size);
+
+int ElementCeil(const float *input, float *output, const int element_size);
+
+int ElementNegative(const float *input, float *output, const int element_size);
+int ElementNegativeInt(const int32_t *input, int32_t *output, const int element_size);
+
+int ElementReciprocal(const float *input, float *output, const int element_size);
+
+int ElementErf(const float *input, float *output, const int element_size);
+
+int ElementIsFinite(const float *input, bool *output, const int element_size);
+
+int ElementMish(const float *input, float *output, const int element_size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_ARITHMETIC_SELF_H_
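For orientation: every ElementXxx entry point in arithmetic_self_fp32.c follows the same split. A SIMD_RUN_* macro consumes as many full vector blocks as the target ISA allows and returns the first index it did not handle, then a plain loop finishes the scalar tail. The stand-alone sketch below mirrors that contract in plain C; BLOCK and FloorBlock are illustrative stand-ins, not nnacl symbols.

#include <math.h>
#include <stdio.h>

/* Stand-in for a vector body: processes BLOCK floats per step and returns
 * the first index it did NOT handle, like the generated SIMD helpers. */
#define BLOCK 4
static int FloorBlock(int index, const float *in, float *out, int size) {
  for (int max = size - BLOCK + 1; index < max; index += BLOCK) {
    for (int k = 0; k < BLOCK; ++k) out[index + k] = floorf(in[index + k]);
  }
  return index;
}

int main(void) {
  float in[6] = {1.9f, -0.5f, 3.0f, 2.2f, -1.1f, 0.7f}, out[6];
  int i = FloorBlock(0, in, out, 6);          /* vector part handles i = 0..3 */
  for (; i < 6; ++i) out[i] = floorf(in[i]);  /* scalar tail, as in ElementFloor */
  for (int k = 0; k < 6; ++k) printf("%.1f ", out[k]);
  printf("\n");
  return 0;
}

Keeping the block body and the scalar tail numerically identical is what lets the generated @SIMD_INSTRUCTION@ variants in the template below substitute for the scalar loop transparently.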
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32_simd.h.in
new file mode 100644
index 00000000..ec29d20f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/arithmetic_self_fp32_simd.h.in
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+#if defined(MS_SIMD_AVX512)
+// only AVX512 supports the abs fp32 instruction
+static inline int ElementAbs@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_ABS_F32(SIMD_LD_F32(input + index)));
+  }
+  return index;
+}
+
+static inline int ElementAbsInt@SIMD_INSTRUCTION@(int index, const int32_t *input, int32_t *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(output + index, SIMD_ABS_EPI32(SIMD_LD_EPI32(input + index)));
+  }
+  return index;
+}
+#endif
+
+#if !defined(MS_SIMD_NEON)
+// not supported on NEON
+static inline int ElementCos@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin = SIMD_LD_F32(input + index);
+    SIMD_ST_F32(output + index, SIMD_COS_F32(vin));
+  }
+  return index;
+}
+
+static inline int ElementLog@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin = SIMD_LD_F32(input + index);
+    SIMD_ST_F32(output + index, SIMD_LOG_F32(vin));
+  }
+  return index;
+}
+#endif
+
+static inline int ElementSquare@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin = SIMD_LD_F32(input + index);
+    SIMD_ST_F32(output + index, SIMD_MUL_F32(vin, vin));
+  }
+  return index;
+}
+
+static inline int ElementSqrt@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_SQRT_F32(SIMD_LD_F32(input + index)));
+  }
+  return index;
+}
+
+static inline int ElementRsqrt@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_RSQRT_F32(SIMD_LD_F32(input + index)));
+  }
+  return index;
+}
+
+static inline int ElementMish@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  SIMD_F32 one = SIMD_MOV_F32(1.0f);
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 exp_add_one = SIMD_ADD_F32(SIMD_EXP_F32(SIMD_LD_F32(input + index)), one);
+    SIMD_F32 exp_pow = SIMD_MUL_F32(exp_add_one, exp_add_one);
+    SIMD_ST_F32(output + index, SIMD_MUL_F32(SIMD_LD_F32(input + index),
+                                             SIMD_DIV_F32(SIMD_SUB_F32(exp_pow, one), SIMD_ADD_F32(exp_pow, one))));
+  }
+  return index;
+}
+
+#if defined(MS_SIMD_AVX) || defined(MS_SIMD_SSE)
+// AVX512 doesn't support the round fp32 instruction
+static inline int ElementRound@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_ROUND_F32(SIMD_LD_F32(input + index)));
+  }
+  return index;
+}
+#endif
+
+#ifndef MS_SIMD_NEON
+// NEON doesn't support the floor fp32 instruction
+static inline int ElementFloor@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_FLOOR_F32(SIMD_LD_F32(input + index)));
+  }
+  return index;
+}
+#endif
+
+#ifndef MS_SIMD_NEON
+static inline int ElementCeil@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_CEIL_F32(SIMD_LD_F32(input + index)));
+  }
+  return index;
+}
+#endif
+
+static inline int ElementNegative@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_MUL_N_F32(SIMD_LD_F32(input + index), -1.0f));
+  }
+  return index;
+}
+
+static inline int ElementNegativeInt@SIMD_INSTRUCTION@(int index, const int32_t *input, int32_t *output, const int element_size) {
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(output + index, SIMD_MUL_N_EPI32(SIMD_LD_EPI32(input + index), -1));
+  }
+  return index;
+}
+
+static inline int ElementReciprocal@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) {
+  SIMD_F32 num1 = SIMD_MOV_F32(1.0f);
+  for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_DIV_F32(num1, SIMD_LD_F32(input + index)));
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
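The closed form used by ElementMish here and in arithmetic_self_fp32.c follows from mish(x) = x * tanh(softplus(x)): with p = (e^x + 1)^2, tanh(ln(1 + e^x)) = (p - 1)/(p + 1), which is exactly the exp_pow expression in both kernels. A stand-alone numerical check of that identity (illustrative only, not part of nnacl):

#include <math.h>
#include <stdio.h>

/* Checks mish(x) = x * tanh(softplus(x)) against the fused form used by
 * ElementMish: with p = (e^x + 1)^2, tanh(ln(1 + e^x)) == (p - 1) / (p + 1). */
int main(void) {
  for (int k = -8; k <= 8; ++k) {
    float x = 0.5f * (float)k;
    float direct = x * tanhf(log1pf(expf(x)));      /* textbook definition */
    float p = (expf(x) + 1.0f) * (expf(x) + 1.0f);  /* exp_pow in the kernel */
    float fused = x * (p - 1.0f) / (p + 1.0f);
    printf("x=% .1f  direct=% .6f  fused=% .6f\n", x, direct, fused);
  }
  return 0;
}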
+ */ + +#include "nnacl_c/fp32/attention_fp32.h" +#include +#include +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/transpose_parameter.h" +#include "nnacl_c/fp32/softmax_fp32.h" +#include "nnacl_c/errorcode.h" + +int InitMatrix(Matrix *matrix, int batch, int row, int col, bool is_trans) { + if (matrix == NULL) { + return NNACL_NULL_PTR; + } + matrix->batch_ = batch; + matrix->row_ = row; + matrix->col_ = col; + matrix->is_transpose_ = is_trans; + matrix->data_ = NULL; + matrix->packed_data_ = NULL; + return NNACL_OK; +} + +size_t LeftMatrixPackElementSize(Matrix *matrix, int row_tile) { + if (matrix == NULL) { + return 0; + } + int real_row = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int deep = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = real_row == 1; + int row_align = vec_matmul ? 1 : UP_ROUND(real_row, row_tile); + int dst_area = row_align * deep; + matrix->packed_row_ = row_align; + matrix->packed_col_ = deep; + return matrix->batch_ * dst_area; +} + +size_t RightMatrixPackElementSize(Matrix *matrix, int col_tile) { + if (matrix == NULL) { + return 0; + } + int deep = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int real_col = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = deep == 1; + int col_align = vec_matmul ? real_col : UP_ROUND(real_col, col_tile); + int dst_area = deep * col_align; + matrix->packed_row_ = deep; + matrix->packed_col_ = col_align; + return matrix->batch_ * dst_area; +} + +int PackLeftMatrix(Matrix *matrix, int row_tile) { + if (matrix == NULL || matrix->data_ == NULL || row_tile == 0) { + return NNACL_NULL_PTR; + } + int real_row = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int deep = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = real_row == 1; + int row_align = vec_matmul ? 
1 : UP_ROUND(real_row, row_tile); + int src_area = matrix->row_ * matrix->col_; + int dst_area = row_align * deep; + bool malloced = false; + if (matrix->packed_data_ == NULL) { + matrix->packed_data_ = (float *)malloc(dst_area * matrix->batch_ * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + malloced = true; + } + + if (vec_matmul) { + memcpy(matrix->packed_data_, matrix->data_, matrix->batch_ * dst_area * sizeof(float)); + } else { + for (int i = 0; i < matrix->batch_; i++) { + const float *cur_src = matrix->data_ + i * src_area; + float *cur_dst = matrix->packed_data_ + i * dst_area; + switch (row_tile) { + case C6NUM: + if (matrix->is_transpose_) { + RowMajor2Row6Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col6Major(cur_src, cur_dst, real_row, deep); + } + break; + case C4NUM: + if (matrix->is_transpose_) { + RowMajor2Row4Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col4Major(cur_src, cur_dst, real_row, deep); + } + break; + case C12NUM: + if (matrix->is_transpose_) { + RowMajor2Row12Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col12Major(cur_src, cur_dst, real_row, deep); + } + break; + default: + if (malloced) { + free(matrix->packed_data_); + matrix->packed_data_ = NULL; + return NNACL_ERR; + } + break; + } + } + } + matrix->packed_row_ = row_align; + matrix->packed_col_ = deep; + return NNACL_OK; +} + +int PackRightMatrix(Matrix *matrix, int col_tile) { + if (matrix == NULL || matrix->data_ == NULL || col_tile == 0) { + return NNACL_NULL_PTR; + } + int deep = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int real_col = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = deep == 1; + int col_align = vec_matmul ? real_col : UP_ROUND(real_col, col_tile); + int src_area = matrix->row_ * matrix->col_; + int dst_area = deep * col_align; + bool malloced = false; + if (matrix->packed_data_ == NULL) { + matrix->packed_data_ = (float *)malloc(dst_area * matrix->batch_ * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + malloced = true; + } + if (vec_matmul) { + memcpy(matrix->packed_data_, matrix->data_, matrix->batch_ * dst_area * sizeof(float)); + } else { + for (int i = 0; i < matrix->batch_; i++) { + const float *cur_src = matrix->data_ + i * src_area; + float *cur_dst = matrix->packed_data_ + i * dst_area; + switch (col_tile) { + case C16NUM: + if (matrix->is_transpose_) { + RowMajor2Col16Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row16Major(cur_src, cur_dst, deep, real_col); + } + break; + case C4NUM: + if (matrix->is_transpose_) { + RowMajor2Col4Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row4Major(cur_src, cur_dst, deep, real_col); + } + break; + case C8NUM: + if (matrix->is_transpose_) { + RowMajor2Col8Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row8Major(cur_src, cur_dst, deep, real_col); + } + break; + default: + if (malloced) { + free(matrix->packed_data_); + matrix->packed_data_ = NULL; + return NNACL_ERR; + } + break; + } + } + } + matrix->packed_row_ = deep; + matrix->packed_col_ = col_align; + return NNACL_OK; +} + +int PackAttentionBias(Matrix *matrix, int tile) { + if (matrix == NULL || matrix->batch_ != 1 || matrix->row_ != 1 || matrix->data_ == NULL) { + return NNACL_PARAM_INVALID; + } + if (tile == 0) { + return NNACL_OK; + } + int size = matrix->col_; + float *src = matrix->data_; + int size_align = UP_ROUND(size, tile); + if (size_align <= 0) { + return 
NNACL_ERR;
+  }
+  matrix->packed_data_ = (float *)malloc(size_align * sizeof(float));
+  if (matrix->packed_data_ == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  matrix->packed_row_ = matrix->row_;
+  matrix->packed_col_ = size_align;
+  memset(matrix->packed_data_, 0, size_align * sizeof(float));
+  memcpy(matrix->packed_data_, src, size * sizeof(float));
+  return NNACL_OK;
+}
+
+static void RelativeShiftPad(const float *input_data, float *output_data, const int32_t *input_shape, int tid,
+                             int thread_num) {
+  int row = input_shape[0];
+  int col = input_shape[1];
+  int out_area = row * (col + 1);
+  memset(output_data, 0, out_area * sizeof(float));
+  for (int r = tid; r < row; r += thread_num) {
+    float *dst = output_data + r * (col + 1);
+    const float *src = input_data + r * col;
+    memcpy(dst, src, col * sizeof(float));
+  }
+  int tile = row % thread_num;
+  for (int r = row - tile; r < row; r++) {
+    float *dst = output_data + r * (col + 1);
+    const float *src = input_data + r * col;
+    memcpy(dst, src, col * sizeof(float));
+  }
+}
+
+static void RelativeShiftSlice(const float *input_data, float *output_data, const int32_t *input_shape, int tid,
+                               int thread_num) {
+  int row = input_shape[0];
+  int col = input_shape[1];
+  int begin = row;
+  memset(output_data, 0, row * row * sizeof(float));
+  for (int r = tid; r < row; r += thread_num) {
+    float *dst = output_data + r * row;
+    const float *src = input_data + r * col + begin;
+    memcpy(dst, src, (col / 2) * sizeof(float));
+  }
+  int tile = row % thread_num;
+  for (int r = row - tile; r < row; r++) {
+    float *dst = output_data + r * row;
+    const float *src = input_data + r * col + begin;
+    memcpy(dst, src, (col / 2) * sizeof(float));
+  }
+}
+
+static void RelativeShift(const Matrix *x, float *pad_buf, float *slice_buf) {
+  int x_area = x->row_ * x->col_;
+  int pad_area = x->row_ * (x->col_ + 1);
+  int slice_area = x->row_ * (x->col_ / 2);
+  int input_shape[] = {x->row_, x->col_};
+  memset(slice_buf, 0, x->batch_ * x->row_ * (x->col_ / 2) * sizeof(float));
+  for (int i = 0; i < x->batch_; i++) {
+    float *cur_x_data = x->data_ + i * x_area;
+    memset(pad_buf, 0, pad_area * sizeof(float));
+    // pad: [row, col + 1]
+    RelativeShiftPad(cur_x_data, pad_buf, input_shape, 0, 1);
+    // reshape: [col + 1, row]
+    // slice last row: [col, row]
+    // reshape: [row, col]
+    // slice col -> [row, row + col / 2]: [row, col / 2]
+    float *cur_slice_data = slice_buf + i * slice_area;
+    RelativeShiftSlice(pad_buf, cur_slice_data, input_shape, 0, 1);
+  }
+}
+
+static void ElementOptAddDiv(const float *input0, const float *input1, const float input2, float *output,
+                             const int batch, const int area) {
+  const float mul = 1 / input2;
+  for (int b = 0; b < batch; b++) {
+    const float *cur_input0 = input0 + b * area;
+    const float *cur_input1 = input1 + b * area;
+    float *cur_output = output + b * area;
+    int index = 0;  // reset for every batch so all batches are processed
+#ifdef ENABLE_NEON
+    for (; index <= area - 4; index += C4NUM) {
+      float32x4_t vin0 = vld1q_f32(cur_input0 + index);
+      float32x4_t vin1 = vld1q_f32(cur_input1 + index);
+      float32x4_t vout = vaddq_f32(vin0, vin1);
+      vout = vmulq_n_f32(vout, mul);
+      vst1q_f32(cur_output + index, vout);
+    }
+#endif
+    for (; index < area; index++) {
+      cur_output[index] = (cur_input0[index] + cur_input1[index]) * mul;  // store, matching the NEON path
+    }
+  }
+}
+
+static bool GetTransposeParameter(TransposeParameter *param, const int in_shape[], int in_shape_len,
+                                  const int out_shape[], int out_shape_len, const int perm[], int perm_len) {
+  param->num_axes_ = perm_len;
+  size_t shape_size = 1;
+  for (int i = 0; i < perm_len; i++) {
+    param->perm_[i] = perm[i];
+    shape_size *= (size_t)in_shape[i];  // total element count; check overflow
+  }
+  param->data_num_ = (int)shape_size;  // check overflow
+  param->strides_[param->num_axes_ - 1] = 1;
+  param->out_strides_[param->num_axes_ - 1] = 1;
+  if (param->num_axes_ - 1 >= in_shape_len) {
+    return false;
+  }
+  if (param->num_axes_ - 1 >= out_shape_len) {
+    return false;
+  }
+  for (int i = param->num_axes_ - 2; i >= 0; i--) {
+    param->strides_[i] = in_shape[i + 1] * param->strides_[i + 1];
+    param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1];
+  }
+  return true;
+}
+
+void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, const Matrix *wq_mat, Matrix *bq_mat,
+                   Matrix *q2wq_mat, Matrix *pu_mat, Matrix *pv_mat, Matrix *q2wq_with_pos_mat,
+                   Matrix *q2wq_with_pu_trans_mat, Matrix *q2wq_with_pv_trans_mat) {
+  int num_heads = param->num_heads_;
+  int d_model = param->d_model_;
+  int batch = param->batch_;
+  int depth = d_model / num_heads;
+  // Q * WQ
+  int q_area = q_mat->packed_row_ * q_mat->packed_col_;
+  int wq_area = wq_mat->packed_row_ * wq_mat->packed_col_;
+  int q2wq_area = q2wq_mat->row_ * q2wq_mat->col_ * q2wq_mat->batch_ / param->batch_;
+  float *q2wq_data = q2wq_mat->data_;
+  memset(q2wq_data, 0, param->batch_ * q2wq_area * sizeof(float));
+  for (int i = 0; i < param->batch_; i++) {
+    float *cur_q = q_mat->packed_data_ + i * q_area;
+    float *cur_wq = wq_mat->packed_data_ + i * wq_area;
+    float *cur_q2wq = q2wq_data + i * q2wq_area;
+    MatMulOpt(cur_q, cur_wq, cur_q2wq, bq_mat->packed_data_, ActType_No, q_mat->col_, q_mat->row_, wq_mat->col_,
+              wq_mat->col_, OutType_Nhwc);
+  }
+  // transpose param init
+  TransposeParameter q_with_pos_trans_param;
+  int q_with_pos_trans_in_shape[] = {batch, param->q_seq_, num_heads, depth};
+  int q_with_pos_trans_out_shape[] = {batch, num_heads, param->q_seq_, depth};
+  int q_with_pos_perm[] = {0, 2, 1, 3};
+  (void)GetTransposeParameter(&q_with_pos_trans_param, q_with_pos_trans_in_shape, 4, q_with_pos_trans_out_shape, 4,
+                              q_with_pos_perm, 4);
+  int q2wq_reshaped_area = q2wq_mat->row_ * q2wq_mat->col_;
+  // Q_WQ + POS_U
+  {
+    float *q_with_pu = q2wq_with_pos_mat->data_;
+    int q_with_pu_area = q2wq_with_pos_mat->row_ * q2wq_with_pos_mat->col_;
+    memset(q_with_pu, 0, q2wq_with_pos_mat->batch_ * q_with_pu_area * sizeof(float));
+    for (int i = 0; i < q2wq_with_pos_mat->batch_; i++) {
+      float *cur_qw = q2wq_data + i * q2wq_reshaped_area;
+      float *cur_q_with_pu = q_with_pu + i * q_with_pu_area;
+      ElementAdd(cur_qw, pu_mat->packed_data_, cur_q_with_pu, q_with_pu_area);
+    }
+    // Q_WITH_U perm [0,2,1,3]
+    float *q_with_pu_trans = q2wq_with_pu_trans_mat->data_;
+    size_t q_with_pu_trans_data_size = (size_t)(q2wq_with_pu_trans_mat->batch_) *
+                                       (size_t)(q2wq_with_pu_trans_mat->row_) * (size_t)(q2wq_with_pu_trans_mat->col_) *
+                                       sizeof(float);
+    memset(q_with_pu_trans, 0, q_with_pu_trans_data_size);
+    TransposeDimsFp32(q_with_pu, q_with_pu_trans, q_with_pos_trans_out_shape, q_with_pos_trans_param.perm_,
+                      q_with_pos_trans_param.strides_, q_with_pos_trans_param.out_strides_,
+                      q_with_pos_trans_param.num_axes_, 0, 1);
+  }
+
+  // Q_WQ + POS_V
+  {
+    float *q_with_pv = q2wq_with_pos_mat->data_;
+    int q_with_pv_area = q2wq_with_pos_mat->row_ * q2wq_with_pos_mat->col_;
+    memset(q_with_pv, 0, q2wq_with_pos_mat->batch_ * q_with_pv_area * sizeof(float));
+    for (int i = 0; i < q2wq_with_pos_mat->batch_; i++) {
+      float *cur_qw = q2wq_data + i * q2wq_reshaped_area;
+      float *cur_q_with_pv = q_with_pv + i * q_with_pv_area;
+      ElementAdd(cur_qw,
pv_mat->packed_data_, cur_q_with_pv, q_with_pv_area); + } + // Q_WITH_V perm [0,2,1,3] + float *q_with_pv_trans = q2wq_with_pv_trans_mat->data_; + size_t q_with_pv_trans_data_size = (size_t)(q2wq_with_pv_trans_mat->batch_) * + (size_t)(q2wq_with_pv_trans_mat->row_) * (size_t)(q2wq_with_pv_trans_mat->col_) * + sizeof(float); + memset(q_with_pv_trans, 0, q_with_pv_trans_data_size); + TransposeDimsFp32(q_with_pv, q_with_pv_trans, q_with_pos_trans_out_shape, q_with_pos_trans_param.perm_, + q_with_pos_trans_param.strides_, q_with_pos_trans_param.out_strides_, + q_with_pos_trans_param.num_axes_, 0, 1); + } +} + +void KMulWeightK(RelativePositionAttentionParameter *param, Matrix *k_mat, const Matrix *wk_mat, Matrix *bk_mat, + Matrix *k2wk_mat, Matrix *k2wk_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // K * WK + int k_area = k_mat->packed_row_ * k_mat->packed_col_; + int wk_area = wk_mat->packed_row_ * wk_mat->packed_col_; + int k2wk_area = k2wk_mat->row_ * k2wk_mat->col_ * k2wk_mat->batch_ / param->batch_; + float *k2wk = k2wk_mat->data_; + memset(k2wk, 0, param->batch_ * k2wk_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_k = k_mat->packed_data_ + i * k_area; + float *cur_wk = wk_mat->packed_data_ + i * wk_area; + float *cur_k2wk = k2wk + i * k2wk_area; + MatMulOpt(cur_k, cur_wk, cur_k2wk, bk_mat->packed_data_, ActType_No, k_mat->col_, k_mat->row_, wk_mat->col_, + wk_mat->col_, OutType_Nhwc); + } + // K * WK perm [0,2,3,1] + float *k2wk_trans_data = k2wk_trans_mat->data_; + int k2wk_trans_area = k2wk_trans_mat->row_ * k2wk_trans_mat->col_; + memset(k2wk_trans_data, 0, k2wk_trans_mat->batch_ * k2wk_trans_area * sizeof(float)); + TransposeParameter k2wk_trans_param; + int k2wk_in_shape[] = {batch, param->k_seq_, num_heads, depth}; + int k2wk_out_shape[] = {batch, num_heads, depth, param->k_seq_}; + int k2wk_perm[] = {0, 2, 3, 1}; + (void)GetTransposeParameter(&k2wk_trans_param, k2wk_in_shape, 4, k2wk_out_shape, 4, k2wk_perm, 4); + TransposeDimsFp32(k2wk, k2wk_trans_data, k2wk_out_shape, k2wk_trans_param.perm_, k2wk_trans_param.strides_, + k2wk_trans_param.out_strides_, k2wk_trans_param.num_axes_, 0, 1); +} + +void VMulWeightV(RelativePositionAttentionParameter *param, Matrix *v_mat, const Matrix *wv_mat, Matrix *bv_mat, + Matrix *v2wv_mat, Matrix *v2wv_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // V * WV + int v_area = v_mat->packed_row_ * v_mat->packed_col_; + int wv_area = wv_mat->packed_row_ * wv_mat->packed_col_; + int v2wv_area = v2wv_mat->row_ * v2wv_mat->col_ * v2wv_mat->batch_ / param->batch_; + float *v2wv = v2wv_mat->data_; + memset(v2wv, 0, param->batch_ * v2wv_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_v = v_mat->packed_data_ + i * v_area; + float *cur_wv = wv_mat->packed_data_ + i * wv_area; + float *cur_v2wv = v2wv + i * v2wv_area; + MatMulOpt(cur_v, cur_wv, cur_v2wv, bv_mat->packed_data_, ActType_No, v_mat->col_, v_mat->row_, wv_mat->col_, + wv_mat->col_, OutType_Nhwc); + } + // V * WV perm [0,2,1,3] + float *v2wv_trans_data = v2wv_trans_mat->data_; + int v2wv_trans_area = v2wv_trans_mat->row_ * v2wv_trans_mat->col_; + memset(v2wv_trans_data, 0, v2wv_trans_mat->batch_ * v2wv_trans_area * sizeof(float)); + TransposeParameter v2wv_trans_param; + int v2wv_in_shape[] = {batch, param->v_seq_, num_heads, depth}; 
+ int v2wv_out_shape[] = {batch, num_heads, param->v_seq_, depth}; + int v2wv_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&v2wv_trans_param, v2wv_in_shape, 4, v2wv_out_shape, 4, v2wv_perm, 4); + TransposeDimsFp32(v2wv, v2wv_trans_data, v2wv_out_shape, v2wv_trans_param.perm_, v2wv_trans_param.strides_, + v2wv_trans_param.out_strides_, v2wv_trans_param.num_axes_, 0, 1); +} + +void PMulWeightP(RelativePositionAttentionParameter *param, Matrix *p_mat, const Matrix *wp_mat, Matrix *p2wp_mat, + Matrix *p2wp_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + + // P * WP + int p_area = p_mat->packed_row_ * p_mat->packed_col_; + int wp_area = wp_mat->packed_row_ * wp_mat->packed_col_; + int p2wp_area = p2wp_mat->row_ * p2wp_mat->col_ * p2wp_mat->batch_ / param->batch_; + float *p2wp_data = p2wp_mat->data_; + memset(p2wp_data, 0, param->batch_ * p2wp_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_p = p_mat->packed_data_ + i * p_area; + float *cur_wp = wp_mat->packed_data_ + i * wp_area; + float *cur_p2wp = p2wp_data + i * p2wp_area; + MatMulOpt(cur_p, cur_wp, cur_p2wp, NULL, ActType_No, p_mat->col_, p_mat->row_, wp_mat->col_, wp_mat->col_, + OutType_Nhwc); + } + // P * WP perm [0,2,3,1] + float *p2wp_trans_data = p2wp_trans_mat->data_; + int p2wp_trans_area = p2wp_trans_mat->row_ * p2wp_trans_mat->col_; + memset(p2wp_trans_data, 0, p2wp_trans_mat->batch_ * p2wp_trans_area * sizeof(float)); + TransposeParameter p2wp_trans_param; + int p2wp_in_shape[] = {batch, param->p_seq_, num_heads, depth}; + int p2wp_out_shape[] = {batch, num_heads, depth, param->p_seq_}; + int p2wp_perm[] = {0, 2, 3, 1}; + (void)GetTransposeParameter(&p2wp_trans_param, p2wp_in_shape, 4, p2wp_out_shape, 4, p2wp_perm, 4); + TransposeDimsFp32(p2wp_data, p2wp_trans_data, p2wp_out_shape, p2wp_trans_param.perm_, p2wp_trans_param.strides_, + p2wp_trans_param.out_strides_, p2wp_trans_param.num_axes_, 0, 1); +} + +void CalculateLogits(RelativePositionAttentionParameter *param, Matrix *q2wq_with_pu_trans_mat, + Matrix *q2wq_with_pv_trans_mat, Matrix *k2wk_trans_mat, Matrix *p2wp_trans_mat, + Matrix *logits_with_u_mat, Matrix *logits_with_v_mat, Matrix *logits_with_v_pad_mat, + Matrix *logits_with_v_shifted_mat, Matrix *logits_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int depth = d_model / num_heads; + + // pack Q_WITH_U as left_matrix + // since we malloc dst data, pack function can not be failed + (void)PackLeftMatrix(q2wq_with_pu_trans_mat, param->row_tile_); + // pack Q_WITH_V as left_matrix + (void)PackLeftMatrix(q2wq_with_pv_trans_mat, param->row_tile_); + // pack K * WK as right_matrix + (void)PackRightMatrix(k2wk_trans_mat, param->col_tile_); + // pack P * WP as right_matrix + (void)PackRightMatrix(p2wp_trans_mat, param->col_tile_); + + // q_with_pu * k = logits_with_u + MatMulOpt(q2wq_with_pu_trans_mat->packed_data_, k2wk_trans_mat->packed_data_, logits_with_u_mat->data_, NULL, + ActType_No, q2wq_with_pu_trans_mat->col_, logits_with_u_mat->row_, logits_with_u_mat->col_, + logits_with_u_mat->col_, OutType_Nhwc); + + // q_with_pv * p = logits_with_v + MatMulOpt(q2wq_with_pv_trans_mat->packed_data_, p2wp_trans_mat->packed_data_, logits_with_v_mat->data_, NULL, + ActType_No, q2wq_with_pv_trans_mat->col_, logits_with_v_mat->row_, logits_with_v_mat->col_, + logits_with_v_mat->col_, OutType_Nhwc); + // relative shift logits_with_v + float *pad_buf = 
logits_with_v_pad_mat->data_; + float *logits_with_v_shifted_data = logits_with_v_shifted_mat->data_; + RelativeShift(logits_with_v_mat, pad_buf, logits_with_v_shifted_data); + // logits = (logits_with_u + logits_with_v) / sqrt(depth) + float *logits_buffer = logits_mat->data_; + ElementOptAddDiv(logits_with_u_mat->data_, logits_with_v_shifted_data, 1 / sqrt(depth), logits_buffer, + logits_with_u_mat->batch_, logits_with_u_mat->row_ * logits_with_u_mat->col_); +} + +void RelPosAttention(RelativePositionAttentionParameter *param, Matrix *logits_mat, Matrix *softmax_mat, + Matrix *v2wv_trans_mat, Matrix *logits2v_mat, Matrix *logits2v_trans_mat, const Matrix *wo_mat, + Matrix *bo_mat, Matrix *output_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + float *logits_buffer = logits_mat->data_; + // softmax(logits) + SoftmaxLastAxis(logits_buffer, softmax_mat->data_, batch * num_heads * softmax_mat->row_, softmax_mat->col_); + + // logits * v + (void)PackLeftMatrix(softmax_mat, param->row_tile_); + (void)PackRightMatrix(v2wv_trans_mat, param->col_tile_); + int softmax_logits_area = softmax_mat->packed_row_ * softmax_mat->packed_col_; + int v2wv_area = v2wv_trans_mat->packed_row_ * v2wv_trans_mat->packed_col_; + int logits2v_area = logits2v_mat->row_ * logits2v_mat->col_; + float *logits2v_data = logits2v_mat->data_; + memset(logits2v_data, 0, logits2v_mat->batch_ * logits2v_area * sizeof(float)); + for (int i = 0; i < logits2v_mat->batch_; i++) { + float *cur_logits = softmax_mat->packed_data_ + i * softmax_logits_area; + float *cur_v2wv = v2wv_trans_mat->packed_data_ + i * v2wv_area; + float *cur_logits2v = logits2v_data + i * logits2v_area; + MatMulOpt(cur_logits, cur_v2wv, cur_logits2v, NULL, ActType_No, softmax_mat->col_, softmax_mat->row_, + v2wv_trans_mat->col_, v2wv_trans_mat->col_, OutType_Nhwc); + } + // multi_head output perm [0,2,1,3] + float *logits2v_trans_data = logits2v_trans_mat->data_; + int logits2v_trans_area = logits2v_trans_mat->row_ * logits2v_trans_mat->col_; + memset(logits2v_trans_data, 0, logits2v_trans_mat->batch_ * logits2v_trans_area * sizeof(float)); + TransposeParameter logits2v_trans_param; + int logits2v_trans_in_shape[] = {batch, num_heads, param->q_seq_, depth}; + int logits2v_trans_out_shape[] = {batch, param->q_seq_, num_heads, depth}; + int logits2v_trans_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&logits2v_trans_param, logits2v_trans_in_shape, 4, logits2v_trans_out_shape, 4, + logits2v_trans_perm, 4); + TransposeDimsFp32(logits2v_data, logits2v_trans_data, logits2v_trans_out_shape, logits2v_trans_param.perm_, + logits2v_trans_param.strides_, logits2v_trans_param.out_strides_, logits2v_trans_param.num_axes_, 0, + 1); + // concat = reshape [batch, -1, d_model] + logits2v_trans_mat->batch_ = batch; + logits2v_trans_mat->row_ = param->q_seq_; + logits2v_trans_mat->col_ = param->d_model_; + // * o + (void)PackLeftMatrix(logits2v_trans_mat, param->row_tile_); + int concat_out_area = logits2v_trans_mat->packed_row_ * logits2v_trans_mat->packed_col_; + int wo_area = wo_mat->packed_row_ * wo_mat->packed_col_; + int output_area = output_mat->row_ * output_mat->col_; + for (int i = 0; i < output_mat->batch_; i++) { + float *cur_concat_out = logits2v_trans_mat->packed_data_ + i * concat_out_area; + float *cur_wo = wo_mat->packed_data_ + i * wo_area; + float *cur_output = output_mat->data_ + i * output_area; + MatMulOpt(cur_concat_out, cur_wo, cur_output, 
bo_mat->packed_data_, ActType_No, logits2v_trans_mat->col_, + logits2v_trans_mat->row_, wo_mat->col_, wo_mat->col_, OutType_Nhwc); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.h new file mode 100644 index 00000000..68f8b05f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/attention_fp32.h @@ -0,0 +1,72 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ +#define MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ + +#include "nnacl_c/attention_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct Matrix { + float *data_; + int row_; + int col_; + float *packed_data_; + int packed_row_; + int packed_col_; + int batch_; + bool is_transpose_; +} Matrix; + +int InitMatrix(Matrix *matrix, int batch, int row, int col, bool is_trans); + +size_t LeftMatrixPackElementSize(Matrix *matrix, int row_tile); + +size_t RightMatrixPackElementSize(Matrix *matrix, int col_tile); + +int PackLeftMatrix(Matrix *matrix, int row_tile); + +int PackRightMatrix(Matrix *matrix, int col_tile); + +int PackAttentionBias(Matrix *matrix, int tile); + +void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, const Matrix *wq_mat, Matrix *bq_mat, + Matrix *q2wq_mat, Matrix *pu_mat, Matrix *pv_mat, Matrix *q2wq_with_pos_mat, + Matrix *q2wq_with_pu_trans_mat, Matrix *q2wq_with_pv_trans_mat); + +void KMulWeightK(RelativePositionAttentionParameter *param, Matrix *k_mat, const Matrix *wk_mat, Matrix *bk_mat, + Matrix *k2wk_mat, Matrix *k2wk_trans_mat); + +void VMulWeightV(RelativePositionAttentionParameter *param, Matrix *v_mat, const Matrix *wv_mat, Matrix *bv_mat, + Matrix *v2wv_mat, Matrix *v2wv_trans_mat); + +void PMulWeightP(RelativePositionAttentionParameter *param, Matrix *p_mat, const Matrix *wp_mat, Matrix *p2wp_mat, + Matrix *p2wp_trans_mat); + +void CalculateLogits(RelativePositionAttentionParameter *param, Matrix *q2wq_with_pu_trans_mat, + Matrix *q2wq_with_pv_trans_mat, Matrix *k2wk_trans_mat, Matrix *p2wp_trans_mat, + Matrix *logits_with_u_mat, Matrix *logits_with_v_mat, Matrix *logits_with_v_pad_mat, + Matrix *logits_with_v_shifted_mat, Matrix *logits_mat); + +void RelPosAttention(RelativePositionAttentionParameter *param, Matrix *logits_mat, Matrix *softmax_mat, + Matrix *v2wv_trans_mat, Matrix *logits2v_mat, Matrix *logits2v_trans_mat, const Matrix *wo_mat, + Matrix *bo_mat, Matrix *output_mat); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c new file mode 100644 index 00000000..d6a0d53f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under 
the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/batchnorm_fp32.h"
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/batchnorm_fp32_simd.h"
+#include "nnacl_c/kernel/fused_batch_norm.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int FusedBatchNormEval(KernelBase *self) {
+  FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm);
+
+  if (fused_batch_norm->trained_) {
+    TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT];
+    TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT];
+    TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT];
+    TensorC *var_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT];
+    (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor));
+    (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor));
+    (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, NNACLGetSize(mean_tensor));
+    (void)memcpy(fused_batch_norm->bn_.variance_, var_tensor->data_, NNACLGetSize(var_tensor));
+  }
+  return NNACL_OK;
+}
+
+void BatchNormSetupVirtualBatch(KernelBase *self, int virtual_batch_multiplier, int momentum) {
+  BatchNormStruct *bn = (BatchNormStruct *)self;
+  NNACL_CHECK_NULL_RETURN_VOID(bn);
+  if (virtual_batch_multiplier > 0) {
+    float new_momentum = (momentum < 0.0f) ?
(bn->momentum_ / virtual_batch_multiplier) : momentum; + bn->momentum_ = new_momentum; + } + return; +} + +void BatchNormFp32(const float *input, const float *mean, const float *variance, const BatchNormStruct *param, + int task_id, int thread_num, float *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int channel = param->channel_; + int cur_offset = completed_units * channel; + float epsilon = param->epsilon_; + + for (int i = 0; i < cur_unit; i++) { + const float *unit_input = input + cur_offset; + float *unit_output = output + cur_offset; + int c = 0; + + SIMD_RUN_NO_SCALAR(BatchNormFp32, c, unit_input, mean, variance, channel, epsilon, unit_output); + + for (; c < channel; c++) { + float variance_sqrt = sqrtf(variance[c] + epsilon); + unit_output[c] = (unit_input[c] - mean[c]) / variance_sqrt; + } + cur_offset += channel; + } +} + +void FusedBatchNormFp32(const float *input, const float *scale, const float *offset, const float *mean, + const float *variance, const BatchNormStruct *param, int task_id, int thread_num, + float *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int channel = param->channel_; + float epsilon = param->epsilon_; + int cur_offset = completed_units * channel; + + for (int i = 0; i < cur_unit; i++) { + const float *unit_input = input + cur_offset; + float *unit_output = output + cur_offset; + int c = 0; + + SIMD_RUN_NO_SCALAR(FusedBatchNormFp32, c, unit_input, scale, offset, mean, variance, channel, epsilon, unit_output); + + for (; c < channel; c++) { + float variance_sqrt = sqrtf(variance[c] + epsilon); + float norm_val = (unit_input[c] - mean[c]) / variance_sqrt; + unit_output[c] = norm_val * scale[c] + offset[c]; + } + cur_offset += channel; + } +} + +void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormStruct *param, + float *save_mean, float *save_var, bool isBatchNorm2d) { + const float N = (float)param->unit_; + const float VN = N; + const float VNUB = (isBatchNorm2d == false) ? N : ((N > 1.0f) ? 
(N - 1.0f) : 1.0f); + const float momentum = (1.0f - param->momentum_); + + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_mean[c] += input[idx]; + } + } + for (int c = 0; c < param->channel_; c++) { + run_mean[c] /= N; + } + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_var[c] += (input[idx] - run_mean[c]) * (input[idx] - run_mean[c]); + } + } + for (int c = 0; c < param->channel_; c++) { + float unbiased_var = (run_var[c] / VNUB); + run_var[c] = (run_var[c] / VN); + save_mean[c] = momentum * save_mean[c] + (1.0f - momentum) * run_mean[c]; + save_var[c] = momentum * save_var[c] + (1.0f - momentum) * unbiased_var; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h new file mode 100644 index 00000000..f09f4453 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_BATCHNORM_FP32_H_ +#define NNACL_FP32_BATCHNORM_FP32_H_ + +#include "nnacl_c/kernel/batch_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormSetupVirtualBatch(KernelBase *self, int virtual_batch_multiplier, int momentum); +void BatchNormFp32(const float *input, const float *mean, const float *variance, const BatchNormStruct *param, + int task_id, int thread_num, float *output); + +int FusedBatchNormEval(KernelBase *self); +void FusedBatchNormFp32(const float *input, const float *scale, const float *offset, const float *mean, + const float *variance, const BatchNormStruct *param, int task_id, int thread_num, + float *output); +void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormStruct *param, + float *save_mean, float *save_var, bool isBatchNorm2d); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_BATCHNORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in new file mode 100644 index 00000000..fbf5f039 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/batchnorm_fp32_simd.h.in @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_BATCHNORM_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_BATCHNORM_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int BatchNormFp32@SIMD_INSTRUCTION@(int index, const float *input, const float *mean,
+                                                  const float *variance, int channel, float epsilon, float *output) {
+  for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 input_data = SIMD_LD_F32(input + index);
+    SIMD_F32 mean_ = SIMD_LD_F32(mean + index);
+    SIMD_F32 variance_ = SIMD_LD_F32(variance + index);
+    SIMD_F32 variance_sqrt = SIMD_SQRT_F32(SIMD_ADD_F32(variance_, SIMD_MOV_F32(epsilon)));
+    SIMD_F32 output_data = SIMD_DIV_F32(SIMD_SUB_F32(input_data, mean_), variance_sqrt);
+    SIMD_ST_F32(output + index, output_data);
+  }
+  return index;
+}
+
+static inline int FusedBatchNormFp32@SIMD_INSTRUCTION@(int index, const float *input, const float *scale,
+    const float *offset, const float *mean, const float *variance, int channel, float epsilon, float *output) {
+  for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 input_data = SIMD_LD_F32(input + index);
+    SIMD_F32 scale_ = SIMD_LD_F32(scale + index);
+    SIMD_F32 offset_ = SIMD_LD_F32(offset + index);
+    SIMD_F32 mean_ = SIMD_LD_F32(mean + index);
+    SIMD_F32 variance_ = SIMD_LD_F32(variance + index);
+    SIMD_F32 variance_sqrt = SIMD_SQRT_F32(SIMD_ADD_F32(variance_, SIMD_MOV_F32(epsilon)));
+    SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input_data, mean_), variance_sqrt);
+    SIMD_F32 output_data = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_), offset_);
+    SIMD_ST_F32(output + index, output_data);
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
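Both template kernels above compute the standard normalization y = (x - mean) / sqrt(variance + epsilon), with the fused variant additionally applying scale and offset per channel. A scalar reference for a single value, useful when sanity-checking the vector paths (the constants below are illustrative):

#include <math.h>
#include <stdio.h>

/* Scalar reference for one channel element of FusedBatchNormFp32:
 * y = (x - mean) / sqrt(var + eps) * scale + offset */
int main(void) {
  const float x = 2.0f, mean = 1.0f, var = 4.0f, eps = 1e-5f;
  const float scale = 0.5f, offset = 0.1f;
  float y = (x - mean) / sqrtf(var + eps) * scale + offset;
  printf("y = %f\n", y); /* approximately 0.35 */
  return 0;
}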
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32.h
new file mode 100644
index 00000000..206d833a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_H_
+#define MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void BCEWithLogitLoss(const float *logits, const float *label, const float *weight, const float *pos_weight,
+                      int length, bool reduction, float *output, float *reduction_sum);
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32_simd.h.in
new file mode 100644
index 00000000..12e3fcdf
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_logits_loss_fp32_simd.h.in
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int BCEWithLogitLoss@SIMD_INSTRUCTION@(int index, const float *logits, const float *label,
+    const float *weight, const float *pos_weight, int length, bool reduction, float *output,
+    float *reduction_sum) {
+  SIMD_F32 zero = SIMD_SET0_F32;
+  SIMD_F32 ones = SIMD_MOV_F32(1.0f);
+  SIMD_F32 middle_output = SIMD_SET0_F32;
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 logits_tmp = SIMD_LD_F32(logits + index);
+    SIMD_F32 label_tmp = SIMD_LD_F32(label + index);
+    SIMD_F32 weight_tmp = SIMD_LD_F32(weight + index);
+    SIMD_F32 pos_weight_tmp = SIMD_LD_F32(pos_weight + index);
+    SIMD_F32 neg_logits_tmp = SIMD_SUB_F32(zero, logits_tmp);
+    SIMD_F32 max_value = neg_logits_tmp;
+    max_value = SIMD_MAX_F32(max_value, zero);
+    SIMD_F32 neg_max_value = SIMD_SUB_F32(zero, max_value);
+    SIMD_F32 log_weight = SIMD_ADD_F32(SIMD_MUL_F32(SIMD_SUB_F32(pos_weight_tmp, ones), label_tmp), ones);
+    SIMD_F32 log_exp_value =
+      SIMD_LOG_F32(SIMD_ADD_F32(SIMD_HEXP_F32(neg_max_value), SIMD_HEXP_F32(SIMD_SUB_F32(neg_logits_tmp, max_value))));
+    SIMD_F32 loss = SIMD_ADD_F32(SIMD_MUL_F32(SIMD_SUB_F32(ones, label_tmp), logits_tmp),
+                                 SIMD_MUL_F32(log_weight, SIMD_ADD_F32(log_exp_value, max_value)));
+    if (reduction) {
+      middle_output = SIMD_FMADD_F32(loss, weight_tmp, middle_output);
+    } else {
+      SIMD_ST_F32(output + index, SIMD_MUL_F32(loss, weight_tmp));
+    }
+  }
+  if (reduction) {
+    *reduction_sum += SIMD_GET_SUM_F32(middle_output);
+  }
+  return index;
+}
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_loigts_loss_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bce_with_loigts_loss_fp32.c
new file mode 100644
index
+ SIMD_RUN_NO_SCALAR(BCEWithLogitLoss, i, logits, label, weight, pos_weight, length, reduction, output, + &simd_reduction_output); + for (; i < length; ++i) { + float logits_value = logits[i]; + float label_value = label[i]; + float weight_value = weight[i]; + float post_weight_value = pos_weight[i]; + float max_value = -logits_value; + max_value = max_value > 0.f ? max_value : 0.f; + float log_weight = (post_weight_value - 1.0f) * label_value + 1.0f; + float log_exp_value = logf(expf(-max_value) + expf(-logits_value - max_value)); + float loss = (1.0f - label_value) * logits_value + log_weight * (log_exp_value + max_value); + if (reduction) { + simd_reduction_output += loss * weight_value; + } else { + output[i] = loss * weight_value; + } + } + if (reduction) { + *reduction_sum = simd_reduction_output; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.c new file mode 100644 index 00000000..076672a3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.c @@ -0,0 +1,123 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/fp32/bias_add.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/bias_add_simd.h" + +void BiasAddByInnerCore(const float *input, const float *bias, float *output, int64_t num) { + int64_t index = 0; + + SIMD_RUN_NO_SCALAR(BiasAddByInnerCore, index, input, bias, output, num); + + for (; index < num; ++index) { + output[index] = input[index] + bias[index]; + } +} + +void BiasAddByBatchCore(const float *input, const float *bias, float *output, int64_t num) { + float *output1 = output; + float *output2 = output + num; + float *output3 = output + num * 2; + float *output4 = output + num * 3; + int64_t index = 0; + + SIMD_RUN_NO_SCALAR(BiasAddByBatchCore, index, input, bias, output1, output2, output3, output4, num); + + const float *input_data1 = input; + const float *input_data2 = input + num; + const float *input_data3 = input + num * 2; + const float *input_data4 = input + num * 3; + for (; index < num; ++index) { + output1[index] = input_data1[index] + bias[index]; + output2[index] = input_data2[index] + bias[index]; + output3[index] = input_data3[index] + bias[index]; + output4[index] = input_data4[index] + bias[index]; + } +} + +void DoBiasAddByBatch(const float *input, const float *bias, float *output, int64_t start_inner, int64_t start_outer, + int64_t end_inner, int64_t end_outer, int64_t inner_num) { + const float *cur_bias = bias + start_inner; + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner - start_inner); + return; + } + if (start_inner != 0) { + BiasAddByInnerCore(input, cur_bias, output, inner_num - start_inner); + start_outer += 1; + input += inner_num - start_inner; + cur_bias = bias; + output += inner_num - start_inner; + } + int64_t step = C4NUM * inner_num; + for (; start_outer <= end_outer - C4NUM; start_outer += C4NUM) { + BiasAddByBatchCore(input, cur_bias, output, inner_num); + input += step; + output += step; + } + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(input, cur_bias, output, inner_num); + input += inner_num; + output += inner_num; + } + BiasAddByInnerCore(input, cur_bias, output, end_inner); +} + +void DoBiasAddByInner(const float *input, const float *bias, float *output, int64_t start_inner, int64_t start_outer, + int64_t end_inner, int64_t end_outer, int64_t inner_num) { + const float *cur_bias = bias + start_inner; + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner - start_inner); + return; + } else { + BiasAddByInnerCore(input, cur_bias, output, inner_num - start_inner); + start_outer += 1; + input += inner_num - start_inner; + cur_bias = bias; + output += inner_num - start_inner; + } + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner); + return; + } else { + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(input, cur_bias, output, inner_num); + input += inner_num; + output += inner_num; + } + } + BiasAddByInnerCore(input, bias, output, end_inner); +} + +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority) { + if (inner_num == 0) { + return; + } + int64_t start_outer = start / inner_num; + int64_t start_inner = start % inner_num; + int64_t end_outer = end / inner_num; + int64_t end_inner = end % inner_num; + const float *cur_input = input + start; + float *cur_output = output + start; + + if (batch_priority) { + DoBiasAddByBatch(cur_input, bias, cur_output, start_inner, start_outer, 
end_inner, end_outer, inner_num); + } else { + DoBiasAddByInner(cur_input, bias, cur_output, start_inner, start_outer, end_inner, end_outer, inner_num); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.h new file mode 100644 index 00000000..210b1768 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add.h @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_H_ +#define MINDSPORE_NNACL_FP32_BIAS_ADD_H_ + +#include <stdint.h> +#include <stdbool.h> + +#ifdef __cplusplus +extern "C" { +#endif + +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_BIAS_ADD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add_simd.h.in new file mode 100644 index 00000000..baa787c4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/bias_add_simd.h.in @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_BIAS_ADD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int BiasAddByInnerCore@SIMD_INSTRUCTION@(int index, const float *input, const float *bias, float *output, + int64_t num) { + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(input + index); + SIMD_F32 vin1 = SIMD_LD_F32(bias + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1); + SIMD_ST_F32(output + index, vout); + } + return index; +} +
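+// Adds one bias vector to four rows at a time: SIMD_LDX4_F32 loads the row slices at input, input + num, +// input + 2 * num, and input + 3 * num, so a single bias load is reused for all four output rows.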
+static inline int BiasAddByBatchCore@SIMD_INSTRUCTION@(int index, const float *input, const float *bias, float *output1, + float *output2, float *output3, float *output4, int64_t num) { + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_LDX4_F32(input_data, input + index, num); + SIMD_F32 bias_data = SIMD_LD_F32(bias + index); + SIMD_ST_F32(output1 + index, SIMD_ADD_F32(input_data1, bias_data)); + SIMD_ST_F32(output2 + index, SIMD_ADD_F32(input_data2, bias_data)); + SIMD_ST_F32(output3 + index, SIMD_ADD_F32(input_data3, bias_data)); + SIMD_ST_F32(output4 + index, SIMD_ADD_F32(input_data4, bias_data)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_BIAS_ADD_@SIMD_INSTRUCTION@_H_ \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.c new file mode 100644 index 00000000..070f439a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.c @@ -0,0 +1,79 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/cdist_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/cdist_fp32_simd.h" +
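+// Each Cdist*NormalOpt computes one distance between two length-m rows: the zero-norm counts nonzero differences, +// the one-norm sums |a - b|, the inf-norm takes the maximum, and the two-norm and p-norm paths run a SIMD pass plus a scalar tail.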
+ */ +#include "nnacl_c/fp32/cdist_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/cdist_fp32_simd.h" + +void CdistTwoNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p) { + float result = 0; + int64_t i = 0; + + SIMD_RUN_NO_SCALAR(CdistTwoNormalOpt, i, a, b, &result, m); + + for (; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += x * x; + } + result = sqrtf(result); + *dst = result; + + return; +} + +void CdistPNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p) { + float result = 0; + int64_t i = 0; + + SIMD_RUN_NO_SCALAR(CdistPNormalOpt, i, a, b, &result, m, p); + + for (; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += powf(x, p); + } + result = powf(result, 1.0 / p); + *dst = result; + + return; +} + +void CdistZeroNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += MSMIN(ceilf(x), 1.0f); + } + *c = result; +} + +void CdistOneNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += x; + } + *c = result; +} + +void CdistInfNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result = MSMAX(result, x); + } + *c = result; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.h new file mode 100644 index 00000000..e7f408ac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_CDIST_H_ +#define MINDSPORE_NNACL_FP32_CDIST_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CdistTwoNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p); +void CdistPNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p); + +void CdistZeroNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); +void CdistOneNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); +void CdistInfNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CDIST_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32_simd.h.in new file mode 100644 index 00000000..7e88ea1d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cdist_fp32_simd.h.in @@ -0,0 +1,63 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CDIST_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_CDIST_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t CdistTwoNormalOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + SIMD_F32 tmp_vec = SIMD_SUB_F32(a_vec, b_vec); + tmp_vec = SIMD_ABS_F32(tmp_vec); + result_vec = SIMD_FMADD_F32(tmp_vec, tmp_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +static inline int64_t CdistPNormalOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size, float p) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + SIMD_F32 p_vec = SIMD_MOV_F32(p); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + SIMD_F32 tmp_vec = SIMD_SUB_F32(a_vec, b_vec); + tmp_vec = SIMD_ABS_F32(tmp_vec); + tmp_vec = SIMD_POW_F32(tmp_vec, p_vec); + result_vec = SIMD_ADD_F32(tmp_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.c new file mode 100644 index 00000000..ae9375ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.c @@ -0,0 
+1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/common_func_fp32.h" + +void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t oc_stride, ActType relu_type, int size) { + if (size == 0) { + return; + } + for (size_t oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size; + int oc_mod = oc % size; + for (int hw = 0; hw < (int)plane_size; hw++) { + int src_index = oc_div * size * plane_stride + hw * size + oc_mod; + int dst_index = hw * oc_stride + oc; + float value = src_ptr_[src_index]; + if (bias_ptr != NULL) { + value = value + bias_ptr[oc]; + } + value = (relu_type == ActType_Relu || relu_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value); + value = (relu_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } +} + +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, size_t relu_type) { +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) + PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM); +#else + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = stride * sizeof(float); + PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#endif +} + +void WinogradPostConvFuncFp32CX(const float *cx_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t relu_type) { +#ifdef ENABLE_AVX + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = (plane_stride - plane_size) * C8NUM * sizeof(float); + WinogradPostFuncBiasReluC8(out_ptr, cx_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + size_t oc4mod = output_channel % C4NUM; + size_t oc4div = output_channel - oc4mod; + size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float); + WinogradPostFuncBiasReluC4(out_ptr, cx_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type); +#else + PostConvFuncComm(cx_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type, + C4NUM); +#endif +} + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + const int unitStep = 4 * length; + for (int y = 0; y < h; ++y) { + float *dstY = M + y * w * unitStep; + for (int x = 0; x < w; ++x) { + float *dstX = dstY + x * unitStep; + const float *srcX = S + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float)); + for (int i = 0; i < k; ++i) { + float b = B[i * h + y]; + const float *srcY = srcX + i * w * 
unitStep; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcY[j] * b; + } + } + } + } +} + +// M = S * B , M = h * w * l, S = h * k * l, B = k * w +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + const int unitStep = 4 * length; + for (int y = 0; y < h; ++y) { + float *dstY = M + y * w * unitStep; + const float *srcY = S + y * k * unitStep; + + for (int x = 0; x < w; ++x) { + float *dstX = dstY + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float)); + for (int i = 0; i < k; ++i) { + const float *srcX = srcY + i * unitStep; + float b = B[i * h + x]; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcX[j] * b; + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.h new file mode 100644 index 00000000..400a5f55 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/common_func_fp32.h @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ +#define MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/conv_parameter.h" + +typedef struct ConvDwFp32BorderParam { + float *dst; + const float *src; + const float *weight; + const float *bias; + size_t height; + size_t width; + size_t in_kh_step; + size_t in_kw_step; + size_t kernel_w; + size_t relu; + size_t relu6; +} ConvDwFp32BorderParam; + +#ifdef __cplusplus +extern "C" { +#endif + +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, size_t relu_type); + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); + +void WinogradPostConvFuncFp32CX(const float *cx_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t relu_type); + +void WinogradPostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +void WinogradPostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +#ifdef ENABLE_AVX +void ConvDwFp32Border(ConvDwFp32BorderParam *param); +#else +void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); +#endif +void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, + size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, + size_t in_kh_step, size_t in_kw_step); +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step); +#endif + +#ifdef ENABLE_ARM64 +void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w); + +void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, + size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, + size_t relu6); + +void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size, + int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6); + +void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size, + int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6); + +void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); + +void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); + +void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); +#endif + +#ifdef __cplusplus +} +#endif +#endif /* MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/constant_of_shape_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/constant_of_shape_fp32.h new file mode 100644 index 00000000..a58557df --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/constant_of_shape_fp32.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ +#include <memory.h> +#include <float.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +inline int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +inline int ConstantOfShapeFp32(float *output, int start, int end, float value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c new file mode 100644 index 00000000..0c1cdd2c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.c @@ -0,0 +1,1608 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "nnacl_c/fp32/conv_1x1_avx_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h" + +void Conv1x1SW3x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "vmovups 0x60(%7), %%ymm3\n" + "vmovups (%7, %6, 1), %%ymm4\n" + "vmovups 0x20(%7, %6, 1), %%ymm5\n" + "vmovups 0x40(%7, %6, 1), %%ymm6\n" + "vmovups 0x60(%7, %6, 1), %%ymm7\n" + "vmovups (%7, %6, 2), %%ymm8\n" + "vmovups 0x20(%7, %6, 2), %%ymm9\n" + "vmovups 0x40(%7, %6, 2), %%ymm10\n" + "vmovups 0x60(%7, %6, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vbroadcastss (%0, %4), %%ymm14\n" + "vbroadcastss (%0, %4, 2), %%ymm15\n" + "vmovups (%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 0x20(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 0x40(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 0x60(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vbroadcastss 4(%0, %4), %%ymm14\n" + "vbroadcastss 4(%0, %4, 2), %%ymm15\n" + "vmovups 128(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 160(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 192(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 224(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vbroadcastss 8(%0, %4), %%ymm14\n" + "vbroadcastss 8(%0, %4, 2), %%ymm15\n" + "vmovups 256(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 288(%1), %%ymm12\n" + "vfmadd231ps 
%%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 320(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 352(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vbroadcastss 12(%0, %4), %%ymm14\n" + "vbroadcastss 12(%0, %4, 2), %%ymm15\n" + "vmovups 384(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 416(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 448(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 480(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vbroadcastss 16(%0, %4), %%ymm14\n" + "vbroadcastss 16(%0, %4, 2), %%ymm15\n" + "vmovups 512(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 544(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 576(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 608(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vbroadcastss 20(%0, %4), %%ymm14\n" + "vbroadcastss 20(%0, %4, 2), %%ymm15\n" + "vmovups 640(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 672(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 704(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 736(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vbroadcastss 24(%0, %4), %%ymm14\n" + "vbroadcastss 24(%0, %4, 2), %%ymm15\n" + "vmovups 768(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 800(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 832(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 864(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vbroadcastss 28(%0, %4), %%ymm14\n" + 
"vbroadcastss 28(%0, %4, 2), %%ymm15\n" + "vmovups 896(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 928(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 960(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 992(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + "vmovups %%ymm3, 0x60(%7)\n" + "vmovups %%ymm4, (%7, %6, 1)\n" + "vmovups %%ymm5, 0x20(%7, %6, 1)\n" + "vmovups %%ymm6, 0x40(%7, %6, 1)\n" + "vmovups %%ymm7, 0x60(%7, %6, 1)\n" + "vmovups %%ymm8, (%7, %6, 2)\n" + "vmovups %%ymm9, 0x20(%7, %6, 2)\n" + "vmovups %%ymm10, 0x40(%7, %6, 2)\n" + "vmovups %%ymm11, 0x60(%7, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "vmovups 0x60(%7), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vmovups (%1), %%ymm4\n" + "vmovups 
0x20(%1), %%ymm5\n" + "vmovups 0x40(%1), %%ymm6\n" + "vmovups 0x60(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vmovups 128(%1), %%ymm4\n" + "vmovups 160(%1), %%ymm5\n" + "vmovups 192(%1), %%ymm6\n" + "vmovups 224(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vmovups 256(%1), %%ymm4\n" + "vmovups 288(%1), %%ymm5\n" + "vmovups 320(%1), %%ymm6\n" + "vmovups 352(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vmovups 384(%1), %%ymm4\n" + "vmovups 416(%1), %%ymm5\n" + "vmovups 448(%1), %%ymm6\n" + "vmovups 480(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vmovups 512(%1), %%ymm4\n" + "vmovups 544(%1), %%ymm5\n" + "vmovups 576(%1), %%ymm6\n" + "vmovups 608(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vmovups 640(%1), %%ymm4\n" + "vmovups 672(%1), %%ymm5\n" + "vmovups 704(%1), %%ymm6\n" + "vmovups 736(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vmovups 768(%1), %%ymm4\n" + "vmovups 800(%1), %%ymm5\n" + "vmovups 832(%1), %%ymm6\n" + "vmovups 864(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vmovups 896(%1), %%ymm4\n" + "vmovups 928(%1), %%ymm5\n" + "vmovups 960(%1), %%ymm6\n" + "vmovups 992(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + "vmovups %%ymm3, 0x60(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", 
"%ymm12", "%ymm13", + "%ymm14"); +} + +void Conv1x1SW4x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + asm volatile( + "movq %10, %%rax\n" // dst_flag + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups 0x20(%8), %%ymm1\n" + "vmovups 0x40(%8), %%ymm2\n" + "vmovups (%8, %7, 1), %%ymm3\n" + "vmovups 0x20(%8, %7, 1), %%ymm4\n" + "vmovups 0x40(%8, %7, 1), %%ymm5\n" + "vmovups (%8, %7, 2), %%ymm6\n" + "vmovups 0x20(%8, %7, 2), %%ymm7\n" + "vmovups 0x40(%8, %7, 2), %%ymm8\n" + "vmovups (%9), %%ymm9\n" + "vmovups 0x20(%9), %%ymm10\n" + "vmovups 0x40(%9), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vmovups (%1), %%ymm13\n" + "vmovups 0x20(%1), %%ymm14\n" + "vmovups 0x40(%1), %%ymm15\n" + "vbroadcastss (%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss (%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss (%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss (%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 96(%1), %%ymm13\n" + "vmovups 128(%1), %%ymm14\n" + "vmovups 160(%1), %%ymm15\n" + "vbroadcastss 4(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 4(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 4(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 4(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 192(%1), %%ymm13\n" + "vmovups 224(%1), %%ymm14\n" + "vmovups 256(%1), %%ymm15\n" + "vbroadcastss 8(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 8(%0, %4), %%ymm12\n" 
+ "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 8(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 8(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 288(%1), %%ymm13\n" + "vmovups 320(%1), %%ymm14\n" + "vmovups 352(%1), %%ymm15\n" + "vbroadcastss 12(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 12(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 12(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 12(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 384(%1), %%ymm13\n" + "vmovups 416(%1), %%ymm14\n" + "vmovups 448(%1), %%ymm15\n" + "vbroadcastss 16(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 16(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 16(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 16(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 480(%1), %%ymm13\n" + "vmovups 512(%1), %%ymm14\n" + "vmovups 544(%1), %%ymm15\n" + "vbroadcastss 20(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 20(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 20(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 20(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 576(%1), %%ymm13\n" + "vmovups 608(%1), %%ymm14\n" + "vmovups 640(%1), %%ymm15\n" + "vbroadcastss 24(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 24(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 24(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 24(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 
672(%1), %%ymm13\n" + "vmovups 704(%1), %%ymm14\n" + "vmovups 736(%1), %%ymm15\n" + "vbroadcastss 28(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 28(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 28(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 28(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "addq $768, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %10, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %6, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%8)\n" // dst_0 + "vmovups %%ymm1, 0x20(%8)\n" + "vmovups %%ymm2, 0x40(%8)\n" + "vmovups %%ymm3, (%8, %7, 1)\n" + "vmovups %%ymm4, 0x20(%8, %7, 1)\n" + "vmovups %%ymm5, 0x40(%8, %7, 1)\n" + "vmovups %%ymm6, (%8, %7, 2)\n" + "vmovups %%ymm7, 0x20(%8, %7, 2)\n" + "vmovups %%ymm8, 0x40(%8, %7, 2)\n" + "vmovups %%ymm9, (%9)\n" + "vmovups %%ymm10, 0x20(%9)\n" + "vmovups %%ymm11, 0x40(%9)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vmovups (%1), %%ymm4\n" + "vmovups 0x20(%1), 
%%ymm5\n" + "vmovups 0x40(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vmovups 96(%1), %%ymm4\n" + "vmovups 128(%1), %%ymm5\n" + "vmovups 160(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vmovups 192(%1), %%ymm4\n" + "vmovups 224(%1), %%ymm5\n" + "vmovups 256(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vmovups 288(%1), %%ymm4\n" + "vmovups 320(%1), %%ymm5\n" + "vmovups 352(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vmovups 384(%1), %%ymm4\n" + "vmovups 416(%1), %%ymm5\n" + "vmovups 448(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vmovups 480(%1), %%ymm4\n" + "vmovups 512(%1), %%ymm5\n" + "vmovups 544(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vmovups 576(%1), %%ymm4\n" + "vmovups 608(%1), %%ymm5\n" + "vmovups 640(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vmovups 672(%1), %%ymm4\n" + "vmovups 704(%1), %%ymm5\n" + "vmovups 736(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "addq $768, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6", "%ymm12", "%ymm13", "%ymm14"); +} + +void Conv1x1SW6x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + asm volatile( + "movq %10, %%rax\n" // dst_flag + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups 0x20(%8), %%ymm1\n" + "vmovups (%8, %7, 1), %%ymm2\n" + "vmovups 0x20(%8, %7, 1), %%ymm3\n" + "vmovups (%8, %7, 2), %%ymm4\n" + "vmovups 0x20(%8, %7, 2), %%ymm5\n" + "vmovups (%9), %%ymm6\n" + "vmovups 0x20(%9), %%ymm7\n" + 
"vmovups (%9, %7, 1), %%ymm8\n" + "vmovups 0x20(%9, %7, 1), %%ymm9\n" + "vmovups (%9, %7, 2), %%ymm10\n" + "vmovups 0x20(%9, %7, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%2), %%ymm2\n" + "vmovups 0x20(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups 0x20(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "movq %0, %%rax\n" + "addq %5, %%rax\n" + + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + "vbroadcastss (%0), %%ymm14\n" + "vbroadcastss (%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss (%0, %4, 2), %%ymm14\n" + "vbroadcastss (%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 64(%1), %%ymm12\n" + "vmovups 96(%1), %%ymm13\n" + "vbroadcastss 4(%0), %%ymm14\n" + "vbroadcastss 4(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 4(%0, %4, 2), %%ymm14\n" + "vbroadcastss 4(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 4(%%rax, %4), %%ymm14\n" + "vbroadcastss 4(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 128(%1), %%ymm12\n" + "vmovups 160(%1), %%ymm13\n" + "vbroadcastss 8(%0), %%ymm14\n" + "vbroadcastss 8(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 8(%0, %4, 2), %%ymm14\n" + "vbroadcastss 8(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 8(%%rax, %4), %%ymm14\n" + "vbroadcastss 8(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps 
%%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 192(%1), %%ymm12\n" + "vmovups 224(%1), %%ymm13\n" + "vbroadcastss 12(%0), %%ymm14\n" + "vbroadcastss 12(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 12(%0, %4, 2), %%ymm14\n" + "vbroadcastss 12(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 12(%%rax, %4), %%ymm14\n" + "vbroadcastss 12(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 256(%1), %%ymm12\n" + "vmovups 288(%1), %%ymm13\n" + "vbroadcastss 16(%0), %%ymm14\n" + "vbroadcastss 16(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 16(%0, %4, 2), %%ymm14\n" + "vbroadcastss 16(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 16(%%rax, %4), %%ymm14\n" + "vbroadcastss 16(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 320(%1), %%ymm12\n" + "vmovups 352(%1), %%ymm13\n" + "vbroadcastss 20(%0), %%ymm14\n" + "vbroadcastss 20(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 20(%0, %4, 2), %%ymm14\n" + "vbroadcastss 20(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 20(%%rax, %4), %%ymm14\n" + "vbroadcastss 20(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 384(%1), %%ymm12\n" + "vmovups 416(%1), %%ymm13\n" + "vbroadcastss 24(%0), %%ymm14\n" + "vbroadcastss 24(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 24(%0, %4, 2), %%ymm14\n" + "vbroadcastss 24(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 24(%%rax, %4), %%ymm14\n" + "vbroadcastss 24(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 448(%1), %%ymm12\n" + "vmovups 480(%1), %%ymm13\n" + "vbroadcastss 28(%0), %%ymm14\n" + "vbroadcastss 28(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + 
"vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 28(%0, %4, 2), %%ymm14\n" + "vbroadcastss 28(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 28(%%rax, %4), %%ymm14\n" + "vbroadcastss 28(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "addq $512, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %10, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %6, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%8)\n" // dst_0 + "vmovups %%ymm1, 0x20(%8)\n" + "vmovups %%ymm2, (%8, %7, 1)\n" + "vmovups %%ymm3, 0x20(%8, %7, 1)\n" + "vmovups %%ymm4, (%8, %7, 2)\n" + "vmovups %%ymm5, 0x20(%8, %7, 2)\n" + "vmovups %%ymm6, (%9)\n" // dst+3 + "vmovups %%ymm7, 0x20(%9)\n" + "vmovups %%ymm8, (%9, %7, 1)\n" + "vmovups %%ymm9, 0x20(%9, %7, 1)\n" + "vmovups %%ymm10, (%9, %7, 2)\n" + "vmovups %%ymm11, 0x20(%9, %7, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm12\n" + "vmovups (%1), %%ymm13\n" + "vmovups 0x20(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 4(%0), %%ymm12\n" + "vmovups 64(%1), %%ymm13\n" + "vmovups 96(%1), %%ymm14\n" + 
"vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 8(%0), %%ymm12\n" + "vmovups 128(%1), %%ymm13\n" + "vmovups 160(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 12(%0), %%ymm12\n" + "vmovups 192(%1), %%ymm13\n" + "vmovups 224(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 16(%0), %%ymm12\n" + "vmovups 256(%1), %%ymm13\n" + "vmovups 288(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 20(%0), %%ymm12\n" + "vmovups 320(%1), %%ymm13\n" + "vmovups 352(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 24(%0), %%ymm12\n" + "vmovups 384(%1), %%ymm13\n" + "vmovups 416(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 28(%0), %%ymm12\n" + "vmovups 448(%1), %%ymm13\n" + "vmovups 480(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "addq $512, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm13", "%ymm14"); +} + +void Conv1x1SW12x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + ic_align <<= 3; + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + float *dst_5 = dst + 5 * oc_align / sizeof(float); + float *dst_9 = dst + 9 * oc_align / sizeof(float); + asm volatile( + "movq %12, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups (%8, %7), %%ymm1\n" + "vmovups (%8, %7, 2), %%ymm2\n" + "vmovups (%9), %%ymm3\n" // dst_3 + "vmovups (%8, %7, 4), %%ymm4\n" + "vmovups (%10), %%ymm5\n" // dst_5 + "vmovups (%10, %7, 1), %%ymm6\n" + "vmovups (%10, %7, 2), %%ymm7\n" + "vmovups (%8, %7, 8), %%ymm8\n" + "vmovups (%11), %%ymm9\n" // dst_9 + "vmovups (%11, %7, 1), %%ymm10\n" + "vmovups (%11, %7, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups (%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups (%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups (%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups (%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + 
"vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "movq %0, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $32, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 2b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(dst_flag) // 12 + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x2, %%eax\n" + "je 0f\n" + "movq %0, %%rax\n" + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + "vmovups %%ymm8, (%2, %1, 8)\n" + "vmovups %%ymm9, (%5)\n" // dst_9 + "vmovups %%ymm10, (%5, %1, 1)\n" + "vmovups %%ymm11, (%5, %1, 2)\n" + : + : "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "a"(dst_flag) // 6 + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", 
"%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void Conv1x1SW1x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm12\n" + "vmovups (%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 4(%0), %%ymm12\n" + "vmovups 32(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 8(%0), %%ymm12\n" + "vmovups 64(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 12(%0), %%ymm12\n" + "vmovups 96(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 16(%0), %%ymm12\n" + "vmovups 128(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 20(%0), %%ymm12\n" + "vmovups 160(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 24(%0), %%ymm12\n" + "vmovups 192(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 28(%0), %%ymm12\n" + "vmovups 224(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "addq $256, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm12", "%ymm13"); +} + +// sliding window to compate 1x1 conv in x86 +void Conv1x1SWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param) { + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int ohw = output_h * output_w; + int ohw_step = UP_DIV(ohw, conv_param->thread_num_); + int ohw_start = ohw_step * task_id; + int ohw_end = MSMIN(ohw_start + ohw_step, ohw); + if (ohw_start >= ohw_end) { + return; + } + int act_type = C0NUM; + int oc_tile_ = C8NUM; // oc in algin to C8NUM in x86_64_avx + if (conv_param->act_type_ == ActType_Relu6) { + act_type += C1NUM; + } + if (conv_param->act_type_ == ActType_Relu6 || conv_param->act_type_ == ActType_Relu) { + act_type += C2NUM; + } + int pad_d = conv_param->pad_d_; + int pad_l = conv_param->pad_l_; + int pad_r = conv_param->pad_r_; + int pad_u = conv_param->pad_u_; + int oc_align = sw_param->block_channel_; + int oc_align_float = oc_align * sizeof(float); + int ic_align = sw_param->ic_align_; + int in_sw_step = sw_param->in_sw_step_; + int in_sw_step_float = sw_param->in_sw_step_ * sizeof(float); + int kernel_step = sw_param->kernel_step_; + int oc_num = sw_param->c_block_; + int in_step = sw_param->in_step_; + int out_step = sw_param->out_step_; + const int ow_block_num[4] = {12, 6, 4, 3}; + const Conv1x1SWAVXKernel kernel[4][2] = 
{{Conv1x1SW1x8AVXKernel, Conv1x1SW12x8AVXKernel}, + {Conv1x1SW1x16AVXKernel, Conv1x1SW6x16AVXKernel}, + {Conv1x1SW1x24AVXKernel, Conv1x1SW4x24AVXKernel}, + {Conv1x1SW1x32AVXKernel, Conv1x1SW3x32AVXKernel}}; + for (int b = 0; b < conv_param->output_batch_; b++) { + int ic_block = 128; + int dst_flag = 0; + for (int ic = 0; ic < ic_align; ic += ic_block) { + if (ic_align - ic <= ic_block) { + ic_block = ic_align - ic; + dst_flag = C3NUM - (ic == 0); + } else { + dst_flag = 1 - (ic == 0); + } + if (pad_d == 0 && pad_l == 0 && pad_r == 0 && pad_u == 0) { + const float *bias = bias_data; + int oc_block = 0; + for (int oc = 0; oc < oc_num; oc += oc_block) { + oc_block = MSMIN(C4NUM, oc_num - oc); // 4 3 2 1 + const float *weight = packed_weight + oc * kernel_step + ic * C8NUM * oc_block; + if (bias != NULL) { + bias = bias_data + oc * oc_tile_; + } + const float *src_w = input_data + ic + ohw_start * in_sw_step; + float *dst_oc = output_data + oc * oc_tile_; + int hw_block = ow_block_num[oc_block - 1]; + for (int hw = ohw_start; hw < ohw_end; hw += hw_block) { + if (hw_block > ohw_end - hw) { // remaining output width is smaller than a full block; process one pixel at a time + hw_block = 1; + } + float *dst_w = dst_oc + hw * oc_align; + kernel[oc_block - 1][hw_block / ow_block_num[oc_block - 1]](dst_w, src_w, weight, bias, act_type, hw_block, + oc_block, oc_align_float, ic_block >> C3NUM, + in_sw_step_float, dst_flag); + src_w += hw_block * in_sw_step; + } + } + } + } + input_data += in_step; + output_data += out_step; + } // batch loop +} + +#ifdef ENABLE_DEBUG +void Conv1x1SWOWxOCAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + oc_align /= sizeof(float); + in_sw_step /= sizeof(float); + ic_align <<= C3NUM; + __m256 dst_data[12]; + const float *src_sw[12]; + __m256 weight_data[4]; + for (int i = 0; i < C4NUM; ++i) { + weight_data[i] = _mm256_set1_ps(0.0f); + } + for (int i = 0; i < ow_block; ++i) { + if (dst_flag & 0x01) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(dst + i * oc_align + j * C8NUM); + } + } else { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + } + src_sw[i] = src + i * in_sw_step; + } + const float *weight_kernel = weight; + for (int ic = 0; ic < ic_align; ++ic) { + for (int j = 0; j < oc_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] += src_sw[i][ic] * weight_data[j]; + } + } + weight_kernel += C8NUM * oc_block; + } // ic loop + // apply activation and store + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + if (dst_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); + } + } + _mm256_storeu_ps(dst + i * oc_align + j * C8NUM, dst_data[i * oc_block + j]); + } + } +} +#endif
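Kernel selection above indexes the kernel[4][2] table with row = oc_block - 1 (one to four blocks of C8NUM output channels, i.e. 8/16/24/32 channels wide) and column = hw_block / ow_block_num[oc_block - 1], so column 0 is the single-pixel fallback kernel and column 1 the full multi-pixel kernel. A small hypothetical demo of that indexing, not part of the patch:

#include <stdio.h>

/* Demonstrates the kernel[4][2] dispatch used by Conv1x1SWAVXFp32. */
int main(void) {
  const int ow_block_num[4] = {12, 6, 4, 3}; /* full hw_block size per oc_block */
  const char *names[4][2] = {{"Conv1x1SW1x8", "Conv1x1SW12x8"},
                             {"Conv1x1SW1x16", "Conv1x1SW6x16"},
                             {"Conv1x1SW1x24", "Conv1x1SW4x24"},
                             {"Conv1x1SW1x32", "Conv1x1SW3x32"}};
  for (int oc_block = 1; oc_block <= 4; ++oc_block) {
    int full = ow_block_num[oc_block - 1];
    /* the dispatcher only ever passes hw_block == full or hw_block == 1 */
    printf("oc_block=%d: hw_block=%2d -> %s, hw_block=1 -> %s\n", oc_block, full,
           names[oc_block - 1][full / full], names[oc_block - 1][1 / full]);
  }
  return 0;
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h new file mode 100644 index 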
00000000..6a948da3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_avx_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*Conv1x1SWAVXKernel)(float *dst, const float *src, const float *weight, const float *bias, + size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, + size_t in_sw_step, size_t dst_flag); + +void Conv1x1SWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); + +#ifdef ENABLE_DEBUG +void Conv1x1SWOWxOCAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag); +#endif +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h new file mode 100644 index 00000000..51ce48be --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_1x1_x86_fp32.h @@ -0,0 +1,21 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ + +#include "nnacl_c/fp32/conv_1x1_avx_fp32.h" + +#endif // MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.c new file mode 100644 index 00000000..47302410 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.c @@ -0,0 +1,435 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_common_fp32.h" +#include <string.h> +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif +#endif +#include "nnacl_c/fp32/matmul_fp32.h" +void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int kernel_plane = kernel_h * kernel_w; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int out_w = conv_param->output_w_; + if (dilation_h == 0 || dilation_w == 0 || out_w == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int in_w = conv_param->input_w_; + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; + if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) { + continue; + } + int input_stride = (input_h * in_w + input_w) * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (kw_e <= kw_s) { + continue; + } + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); + } + } // kernel_h loop + } + } // tile num loop +} + +// fp32 conv common +void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + Row2ColMajorFuncPtr Row2ColMajor = NULL; + int output_hw = conv_param->output_h_ * conv_param->output_w_; +#ifdef ENABLE_AVX + Row2ColMajor = RowMajor2Col6Major; + const int cal_num = C6NUM; +#elif defined(ENABLE_SSE) + Row2ColMajor = RowMajor2Col4Major; + const int cal_num = C4NUM; +#elif defined(ENABLE_ARM64) + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + int cal_num = 0; + if (output_hw <= C4NUM) { + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + cal_num = 
C4NUM; + } else if (output_hw <= C8NUM) { + Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + cal_num = C8NUM; + } else { + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + cal_num = C12NUM; + } +#elif defined(ENABLE_ARM32) + Row2ColMajor = RowMajor2Col12Major; + const int cal_num = C12NUM; +#else + Row2ColMajor = RowMajor2Col12Major; + const int cal_num = C12NUM; +#endif + + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } + int out_stride = conv_param->output_channel_ * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int out_channel = conv_param->output_channel_; + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel * output_hw + start_hw * out_channel; + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +// x86 func param types are different +#if ENABLE_AVX + MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc); +#elif ENABLE_SSE + MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc); +#elif ENABLE_ARM32 + MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, out_channel, OutType_Nhwc); +#elif ENABLE_ARM64 + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#else + MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#endif + } + } +} + +// fp32 conv common +void ConvFp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *col_major_input, float *output_data, int task_id, + const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + Row2ColMajorFuncPtr Row2ColMajor = NULL; +#ifdef ENABLE_AVX + const int cal_num = C6NUM; + Row2ColMajor = RowMajor2Col6Major; +#elif defined(ENABLE_SSE) + const int cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; +#elif defined(ENABLE_ARM64) + int cal_num = 0; + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + if (output_hw <= C4NUM) { + cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + } else if (output_hw <= C8NUM) { + cal_num = C8NUM; + Row2ColMajor = 
RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + } else { + cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + } +#elif defined(ENABLE_ARM32) + const int cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; +#else + const int cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; +#endif + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + int out_stride = conv_param->output_channel_ * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = start_batch; b < end_batch; b++) { + int out_channel = conv_param->output_channel_; + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel * output_hw; + for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +// x86 func param types are different +#if ENABLE_AVX + MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc); +#elif ENABLE_SSE + MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc); +#elif ENABLE_ARM32 + MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, out_channel, OutType_Nhwc); +#elif ENABLE_ARM64 + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#else + MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#endif + } + } +} + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int input_hw = conv_param->input_h_ * conv_param->input_w_; + int in_channel = conv_param->input_channel_; + Row2ColMajorFuncPtr Row2ColMajor = NULL; + int cal_num = 0; + int out_tile = 0; +#ifdef ENABLE_AVX + cal_num = C6NUM; + out_tile = C8NUM; + Row2ColMajor = RowMajor2Col6Major; + int align_channel = UP_DIV(out_channel, C16NUM) * C16NUM; +#else + out_tile = C4NUM; + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + if (output_hw <= C4NUM) { + cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + } else if (output_hw <= C8NUM) { + cal_num = C8NUM; + 
Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + } else { + cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + } +#endif + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } +#ifdef ENABLE_AVX + int act_type = 0; + if (conv_param->act_type_ == ActType_Relu6) { + act_type += 1; + } + if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { + act_type += 2; + } + int out_stride = out_tile * cal_num; + int out_block_stride = output_hw * C8NUM; +#else + int out_stride = MSMIN(out_channel, out_tile) * cal_num; +#endif + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * in_channel * input_hw; +#ifdef ENABLE_AVX + int out_offset = b * align_channel * output_hw + start_hw * out_tile; +#else + int out_offset = b * out_channel * output_hw + start_hw * MSMIN(out_channel, out_tile); +#endif + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +#ifdef ENABLE_AVX + for (int oc = 0; oc < out_channel; oc += C16NUM) { + CommonConv6x16Kernel(gemm_output + oc * output_hw, col_major_input, packed_weight + oc * deep, bias_data + oc, + deep, out_block_stride, act_type, real_cal_row); + } +#else + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, output_hw, OutType_NC4HW4); +#endif + } + } +} +#endif + +#ifdef ENABLE_AVX +void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t depth, + const size_t out_step, const size_t act_flag, const size_t real_cal_row) { +#define Store1 \ + _mm256_storeu_ps(dst, out[0]); \ + _mm256_storeu_ps(dst + out_step, out[1]); +#define Store2 \ + Store1 _mm256_storeu_ps(dst + C8NUM, out[2]); \ + _mm256_storeu_ps(dst + out_step + C8NUM, out[3]); +#define Store3 \ + Store2 _mm256_storeu_ps(dst + C16NUM, out[4]); \ + _mm256_storeu_ps(dst + out_step + C16NUM, out[5]); +#define Store4 \ + Store3 _mm256_storeu_ps(dst + C24NUM, out[6]); \ + _mm256_storeu_ps(dst + out_step + C24NUM, out[7]); +#define Store5 \ + Store4 _mm256_storeu_ps(dst + C32NUM, out[8]); \ + _mm256_storeu_ps(dst + out_step + C32NUM, out[9]); +#define Store6 \ + Store5 _mm256_storeu_ps(dst + C40NUM, out[10]); \ + _mm256_storeu_ps(dst + out_step + C40NUM, out[11]); + + __m256 out[12]; + if (bias != NULL) { + out[0] = _mm256_loadu_ps(bias); + out[1] = _mm256_loadu_ps(bias + C8NUM); + } else { + out[0] = _mm256_set1_ps(0.0f); + out[1] = _mm256_set1_ps(0.0f); + } + out[2] = out[0]; + out[3] = out[1]; + out[4] = out[0]; + out[5] = out[1]; + out[6] = out[0]; + out[7] = out[1]; + out[8] = out[0]; + out[9] = out[1]; + out[10] = out[0]; + out[11] = out[1]; + for (int d = 0; d < depth; 
++d) { + __m256 w1 = _mm256_loadu_ps(weight); + __m256 w2 = _mm256_loadu_ps(weight + C8NUM); + __m256 s1 = _mm256_set1_ps(*src); + __m256 s2 = _mm256_set1_ps(*(src + 1)); + out[0] = _mm256_fmadd_ps(s1, w1, out[0]); + out[1] = _mm256_fmadd_ps(s1, w2, out[1]); + out[2] = _mm256_fmadd_ps(s2, w1, out[2]); + out[3] = _mm256_fmadd_ps(s2, w2, out[3]); + s1 = _mm256_set1_ps(*(src + 2)); + s2 = _mm256_set1_ps(*(src + 3)); + out[4] = _mm256_fmadd_ps(s1, w1, out[4]); + out[5] = _mm256_fmadd_ps(s1, w2, out[5]); + out[6] = _mm256_fmadd_ps(s2, w1, out[6]); + out[7] = _mm256_fmadd_ps(s2, w2, out[7]); + s1 = _mm256_set1_ps(*(src + 4)); + s2 = _mm256_set1_ps(*(src + 5)); + out[8] = _mm256_fmadd_ps(s1, w1, out[8]); + out[9] = _mm256_fmadd_ps(s1, w2, out[9]); + out[10] = _mm256_fmadd_ps(s2, w1, out[10]); + out[11] = _mm256_fmadd_ps(s2, w2, out[11]); + weight += C16NUM; + src += C6NUM; + } + __m256 six = _mm256_set1_ps(6.0f); + __m256 zero = _mm256_set1_ps(0.0f); + if (0x1 & act_flag) { // relu6 + out[0] = _mm256_min_ps(out[0], six); + out[1] = _mm256_min_ps(out[1], six); + out[2] = _mm256_min_ps(out[2], six); + out[3] = _mm256_min_ps(out[3], six); + out[4] = _mm256_min_ps(out[4], six); + out[5] = _mm256_min_ps(out[5], six); + out[6] = _mm256_min_ps(out[6], six); + out[7] = _mm256_min_ps(out[7], six); + out[8] = _mm256_min_ps(out[8], six); + out[9] = _mm256_min_ps(out[9], six); + out[10] = _mm256_min_ps(out[10], six); + out[11] = _mm256_min_ps(out[11], six); + } + if (0x2 & act_flag) { // relu + out[0] = _mm256_max_ps(out[0], zero); + out[1] = _mm256_max_ps(out[1], zero); + out[2] = _mm256_max_ps(out[2], zero); + out[3] = _mm256_max_ps(out[3], zero); + out[4] = _mm256_max_ps(out[4], zero); + out[5] = _mm256_max_ps(out[5], zero); + out[6] = _mm256_max_ps(out[6], zero); + out[7] = _mm256_max_ps(out[7], zero); + out[8] = _mm256_max_ps(out[8], zero); + out[9] = _mm256_max_ps(out[9], zero); + out[10] = _mm256_max_ps(out[10], zero); + out[11] = _mm256_max_ps(out[11], zero); + } + if (real_cal_row == C6NUM) { + Store6 + } else if (real_cal_row == C5NUM) { + Store5 + } else if (real_cal_row == C4NUM) { + Store4 + } else if (real_cal_row == C3NUM) { + Store3 + } else if (real_cal_row == C2NUM) { + Store2 + } else if (real_cal_row == C1NUM) { + Store1 + } +} + +#endif
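CommonConv6x16Kernel computes a 6-pixel-by-16-channel output tile: each pixel owns two YMM registers, with channels 0-7 stored at dst + r * C8NUM and channels 8-15 at dst + out_step + r * C8NUM, which is what the Store1..Store6 macros unroll. A scalar reference of the same computation, a hypothetical sketch rather than part of the patch:

#include <stddef.h>

/* Hypothetical scalar equivalent of CommonConv6x16Kernel (NC8HW8 store layout). */
void CommonConv6x16Ref(float *dst, const float *src, const float *weight, const float *bias, size_t depth,
                       size_t out_step, size_t act_flag, size_t real_cal_row) {
  float out[6][16];
  for (int r = 0; r < 6; ++r) {
    for (int c = 0; c < 16; ++c) {
      out[r][c] = (bias != NULL) ? bias[c] : 0.0f; /* every row starts from the same bias */
    }
  }
  for (size_t d = 0; d < depth; ++d) {
    for (int r = 0; r < 6; ++r) {
      for (int c = 0; c < 16; ++c) {
        out[r][c] += src[r] * weight[c]; /* rank-1 update, like the FMA block above */
      }
    }
    src += 6;     /* C6NUM pixels consumed per depth step */
    weight += 16; /* C16NUM weights consumed per depth step */
  }
  for (size_t r = 0; r < real_cal_row; ++r) {
    for (int c = 0; c < 16; ++c) {
      float v = out[r][c];
      if (act_flag & 0x1) v = (v > 6.0f) ? 6.0f : v; /* relu6 */
      if (act_flag & 0x2) v = (v < 0.0f) ? 0.0f : v; /* relu */
      dst[(c / 8) * out_step + r * 8 + (c % 8)] = v; /* mirrors Store1..Store6 */
    }
  }
}

diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.h new file mode 100644 index 00000000..35eb7ccd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_common_fp32.h @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 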
+ */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_COMMON_H_ +#define MINDSPORE_NNACL_FP32_CONV_COMMON_H_ + +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/conv_sw_avx_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, int row, int col); +#ifdef ENABLE_ARM64 +typedef void (*MatmulFloatOptFuncPtr)(const float *a, const float *b, float *c, const float *bias, int act_type, + int depth, int row, int col, size_t stride, size_t write_mode); +#endif + +void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index); + +// fp32 convolution common (im2col+gemm) +void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); + +// fp32 convolution common (im2col+gemm) +void ConvFp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *col_major_input, float *output_data, int task_id, + const ConvParameter *conv_param); + +// common convolution with NC4HW4 output; if out_channel is not a multiple of C4NUM, the remaining channels are written as-is, with no zero padding. +void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); + +#ifdef ENABLE_AVX +void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t depth, + size_t out_step, size_t act_flag, size_t real_cal_row); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_COMMON_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.c new file mode 100644 index 00000000..55e11032 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/conv_depthwise_avx_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/activation_fp32.h" + +int ConvDwAVX(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + + int32_t *num_pixels = conv_dw_calc_param->num_pixels_; + int32_t *out_w_start = conv_dw_calc_param->out_w_start_; + int first_calc_kw = conv_dw_calc_param->first_calc_kw_; + + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + bool first_calc_flag = true; + if (first_calc_kw == -1) { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + first_calc_flag = false; + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + if (first_calc_flag) { + int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw; + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + ConvDwAVXFp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_, + conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data); + } + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + if (first_calc_flag && (kw == first_calc_kw)) { + weight_kh += conv_param->output_channel_; + first_calc_flag = false; + continue; + } + int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_; + + ConvDwAVXFp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false, + bias_data); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return 
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.h new file mode 100644 index 00000000..5ab72316 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_avx_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ +#define MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/base/conv_common_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwAVX(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param_); + +void ConvDwAVXFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step, bool first_calc_flag, const float *bias); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c new file mode 100644 index 00000000..c8b7de86 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.c @@ -0,0 +1,2074 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/fp32/activation_fp32.h" + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels, + int output_channel, int input_step) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + *output_ptr++ += weight_ptr[c] * input_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *dw_weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + float *dst_w = dst_data + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + + ConvDwFp32Row(dst_w, src_kw, dw_weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); + dw_weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return NNACL_OK; +} + +#ifdef ENABLE_AVX512 +int ConvDwAVX512(float *output_data, const float *input_data, 
const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + int32_t *num_pixels = conv_dw_calc_param->num_pixels_; + int32_t *out_w_start = conv_dw_calc_param->out_w_start_; + int first_calc_kw = conv_dw_calc_param->first_calc_kw_; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + bool first_calc_flag = true; + if (first_calc_kw == -1) { + first_calc_flag = false; + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + + if (first_calc_flag) { + int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + + ConvDwAVX512Fp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_, + conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data); + } + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + if (first_calc_flag && (kw == first_calc_kw)) { + first_calc_flag = false; + weight_kh += conv_param->output_channel_; + continue; + } + + float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_; + int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + + ConvDwAVX512Fp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false, + bias_data); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } else if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return NNACL_OK; +} +#endif + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + if (block == 0) { + return; + } + int left = 0; + int right = conv_param->output_w_; + int top = 0; + int bottom = conv_param->output_h_; + + while (left * conv_param->stride_w_ < 
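+  // These four while loops shrink left/right/top/bottom to the largest output
+  // region whose receptive field lies entirely inside the input; e.g. left
+  // becomes the smallest ow with ow * stride_w >= pad_l. Everything outside
+  // this region is handled by the border kernels.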
conv_param->pad_l_) { + left++; + } + while ((right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ > + conv_param->input_w_ && + right > left) { + right--; + } + while (top * conv_param->stride_h_ < conv_param->pad_u_) { + top++; + } + while ((bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ > + conv_param->input_h_ && + bottom > top) { + bottom--; + } + sliding->left_ = left; + sliding->right_ = right; + sliding->top_ = top; + sliding->bottom_ = bottom; + sliding->c_block_ = UP_DIV(conv_param->output_channel_, block); + sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block; + sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_; + if (conv_param->out_format_ == Format_NC4HW4) { + // write to nc8hw8 + sliding->out_h_step_ = conv_param->output_w_ * block; + sliding->out_c_step_ = block * conv_param->output_h_ * conv_param->output_w_; + sliding->out_w_step_ = block; + sliding->out_block_step_ = sliding->out_c_step_; + } else { + // write to nhwc + sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_; + sliding->out_c_step_ = block; + sliding->out_w_step_ = sliding->block_channel_; + sliding->out_block_step_ = sliding->out_w_step_; + } +} + +void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block) { + InitSlidingParam(sliding, conv_param, weight_block); + AppendSlidingParamConv(sliding, conv_param, input_block, weight_block); +} + +void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block) { + if (input_block == 0) { // is not aligned + sliding->ic_align_ = conv_param->input_channel_; + } else { // 1x1 input is aligned to input_block + sliding->ic_align_ = UP_DIV(conv_param->input_channel_, input_block) * input_block; + } + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->ic_align_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->ic_align_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->ic_align_ * conv_param->stride_w_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->ic_align_ * conv_param->dilation_w_; // kernel W + sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * sliding->ic_align_ * weight_block; +} + +void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + InitSlidingParam(sliding, conv_param, block); + AppendSlidingParamConvDw(sliding, conv_param, block); +} + +void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W + 
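+  // For depthwise, kernel_step_ is the per-channel-block weight stride: one
+  // full kh * kw window per block of `block` channels (there is no
+  // input-channel dimension here, unlike the regular conv variant above).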
sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block; +} + +/*conv depthwise fp32 begin*/ +void ConvDwBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, bool is_relu6) { + const float *src_kh = src; + const float *weight_kh = weight; + for (int c = 0; c < C4NUM; c++) { + dst[c] = 0; + } + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop + for (int c = 0; c < C4NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void ConvDwBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float *src_h = src + ih * sliding->in_h_step_; + + float *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float *src_w = src_h + iw * sliding->block_channel_; + + const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; +#ifdef ENABLE_AVX + ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam)); + if (param == NULL) { + return; + } + param->dst = dst_kernel; + param->src = src_kernel; + param->weight = weight_kernel; + param->bias = bias; + param->height = end_kh - start_kh; + param->width = end_kw - start_kw; + param->in_kh_step = sliding->in_kh_step_ * sizeof(float); + param->in_kw_step = sliding->in_kw_step_ * sizeof(float); + param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float); + param->relu = relu; + param->relu6 = relu6; + ConvDwFp32Border(param); + free(param); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), + conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); +#else + ConvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6); +#endif + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += 
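+      // Border pixels dispatch per ISA: the AVX path marshals its arguments
+      // through a heap-allocated ConvDwFp32BorderParam (malloc/free per border
+      // pixel), ARM/SSE call ConvDwFp32Border directly, and the portable
+      // fallback is ConvDwBorderPixel above.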
sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void ConvDwCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float *src_kh = src_w; + const float *weight_kh = weight; + for (int c = 0; c < C4NUM; c++) { + dst_w[c] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + // add bias and relu + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +// conv depthwise fp32: sliding window +void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + if (conv_param->thread_num_ == 0) { + return; + } + const float *src = input_data; + float *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + ConvDwBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, sliding); + ConvDwBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, conv_param->output_w_, + conv_param, sliding); + ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_
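+                       // (the ARM/SSE center kernel takes byte strides, hence
+                       // the sizeof(float) scaling; the portable ConvDwCenter
+                       // fallback below works in float elements instead)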
* sizeof(float), + sliding->in_kw_step_ * sizeof(float), relu, relu6); +#else + ConvDwCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu, + relu6); +#endif + } + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc4 +} +/*conv depthwise fp32 end*/ + +/*conv depthwise 3x3 fp32 begin*/ +bool CheckConvDwUse3X3(const ConvParameter *conv_param) { + bool use_3x3 = + conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && + (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) && + (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ && + (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && + conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1; + if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) { + return false; + } + const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_; + const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_; + return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) && + in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_); +} + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +static void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + v0 = MS_MOVQ_F32(0.0f); + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v1 = MS_LDQ_F32(src + ic); + v2 = MS_LDQ_F32(src + channel + ic); + v3 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d1 = src[i + ic]; + float d2 = src[i + ic + channel]; + float d3 = src[i + ic + 2 * channel]; + remain_line[i] = 0.0f - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; + } + } +} + +static void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + v3 = MS_LDQ_F32(src + 3 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + float d3 = src[i + ic + 3 * channel]; + remain_line[i] = d0 - 
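+      // (remainder lanes: the same d0-d2, d1+d2, d2-d1, d3-d1 combinations as
+      // the vector path above -- an F(2,3) Winograd-style input transform --
+      // computed scalar-wise for the channel tail)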
d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; + } + } +} + +static void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + v3 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = 0.0f - d1; + } + } +} + +static void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2; + int ic = 0; + v2 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_STQ_F32(line + lw * ic, v0); + MS_STQ_F32(line + lw * ic + 4, v1); + MS_STQ_F32(line + lw * ic + 8, b2); + memset(line + lw * ic + 12, 0, 16); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + remain_line[i] = d0; + remain_line[i + 4] = d1; + remain_line[i + 8] = 0.0f - d1; + } + } +} + +static void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(line0, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int lw = UP_DIV(width, C2NUM) * C4NUM; + ConvDw3x3RowLeft(src - width * channel, line0, lw, channel); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + 
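+    // lines[0..2] form a rolling three-row window over the transformed input:
+    // the Init* helpers fill it for the first output row (the top row is
+    // zero-padded), and ConvDw3x3Row later rotates one new row in per step.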
ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3Row(const float *src, float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(tmp, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, tmp, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3Bottom(float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float)); +} + +#ifndef ENABLE_ARM64 +void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, + bool relu, bool relu6) { + int channel = ori_channel; + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + for (; channel > 0; channel -= 4) { + MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data); + bias_data += 4; + MS_FLOAT32X4 g00 = MS_LDQ_F32(weight); + MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4); + MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8); + MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12); + MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16); + MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20); + MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24); + MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28); + MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32); + MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36); + MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40); + MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44); + weight += 48; + float *cur_dst = dst; + int ow = 0; + for (; ow < width - 1; ow += 2) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + 
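+      // Output transform of the F(2,3) scheme: res0 = acc0 + acc1 + acc2 and
+      // res1 = acc1 - acc2 + acc3 yield two adjacent output columns per pass.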
MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2)); + res0 = MS_ADDQ_F32(res0, bias); + res1 = MS_ADDQ_F32(res1, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f)); + } + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + MS_STQ_F32(cur_dst + ori_channel, res1); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = MS_F32X4_GETI(res0, i); + cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i); + } + } + cur_dst += 2 * ori_channel; + } + if (ow < width) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + res0 = MS_ADDQ_F32(res0, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + } + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = MS_F32X4_GETI(res0, i); + } + } + } + dst += 4; + } +} +#endif + +void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { + int units = UP_DIV(conv_param->output_w_, C2NUM); + int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); + int line = conv_param->input_channel_ * conv_param->input_w_; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + float *line0 = buffer; + float *line1 = buffer + units * c4 * C4NUM; + float *line2 = buffer + units * c4 * C8NUM; + float *lines[3] = {line0, line1, line2}; + int oh = start_oh; + if (oh == 0) { + // input trans + ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + for (oh = start_oh + 1; oh < end_oh - 1; oh++) { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } + if (oh == conv_param->output_h_ - 1) { + // input trans + ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, 
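+      // Last row of the range: at the image bottom, ConvDw3x3Bottom rotates
+      // the buffers and zero-fills the row below; otherwise one more real
+      // input row is rotated in before the final ConvDw3x3Line call.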
conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } +} +#endif + +/*conv depthwise indirect buffer fp32 begin*/ +bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) { + bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) || + (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5); + return use_indirect; +} + +void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param, + int step_h, int step_w) { +#ifdef ENABLE_AVX + int div = C8NUM; +#else + int div = C4NUM; +#endif + + int ic_div = UP_DIV(conv_param->input_channel_, div) * div; + for (int b = 0; b < conv_param->output_batch_; b++) { + float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h; + float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_div; + for (int oh = 0; oh < conv_param->output_h_; oh++) { + for (int kh = 0; kh < conv_param->kernel_h_; kh++) { + int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_; + if (ih < conv_param->input_h_ && ih >= 0) { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_; + int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh; + if (iw < conv_param->input_w_ && iw >= 0) { + indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_div; + } else { + indirect[index] = zero_ptr; + } + } + } + } else { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh; + indirect[index] = zero_ptr; + } + } + } + } + } + } +} + +#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX) +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + do { + float **in = input; + size_t c = (size_t)channels; + const float *w = weights; + float *out = output; + memcpy(out, bias, channels * (int)sizeof(float)); + for (; c >= C4NUM; c -= C4NUM) { + for (int i = 0; i < C4NUM; i++) { + for (int k = 0; k < kernel; k++) { + out[i] += in[k][i] * w[i + k * C4NUM]; + } + } + w += kernel * C4NUM; + out += C4NUM; + for (int k = 0; k < kernel; k++) { + in[k] += C4NUM; + } + } + for (int i = 0; i < c; i++) { + for (int k = 0; k < kernel; k++) { + out[i] += in[k][i] * w[i + k * C4NUM]; + } + } + if (relu) { + Fp32Relu(output, channels, output); + } + if (relu6) { + Fp32Relu6(output, channels, output); + } + output += channels; + input = input + input_stride; + } while (--output_width != 0); +} +#endif + +#ifdef ENABLE_ARM64 +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + if (kernel == 9) { + ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, + relu6); + } else if (kernel == 25) { + ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, + relu6); + } +} +#endif + +#ifdef ENABLE_AVX +void 
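+// AVX build of the indirect-row kernel: same contract as the ARM64 branch
+// above, dispatching on tap count (9 or 25) to the hand-written
+// ConvDwFp32Avx3x3 / ConvDwFp32Avx5x5 implementations.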
ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + if (kernel == 9) { + ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6); + } else if (kernel == 25) { + ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6); + } +} +#endif + +void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, + float *zero_ptr, const ConvParameter *conv_param, int task_id) { + if (conv_param->thread_num_ == 0) { + return; + } + int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_; + int step_h = + (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_; + int input_stride = conv_param->kernel_h_ * step_w; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + + for (int b = 0; b < conv_param->output_batch_; b++) { + float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h; + float *output_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float **indirect = indirect_b + oh * step_h; + float *output_h = output_b + oh * conv_param->output_w_ * conv_param->output_channel_; + if (conv_param->kernel_w_ == 3) { + ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_, + conv_param->output_w_, input_stride, relu, relu6, 9); + } else if (conv_param->kernel_w_ == 5) { + ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_, + conv_param->output_w_, input_stride, relu, relu6, 25); + } + } + } +} +/*conv depthwise indirect buffer fp32 end*/ + +/*deconv depthwise fp32 begin*/ +void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step, + int in_kw_step, int kernel_w_step) { + float *dst_kh = dst; + const float *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst_kw); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst_kw, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop +} + +void DeconvDwBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right, + const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + const float *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_,
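+                         // Deconv border swaps the roles of the loops: ih/iw
+                         // walk the *input*, and each input pixel scatters into
+                         // a kernel-sized output window, so the start/end taps
+                         // are clipped against the output bounds here.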
UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float *dst_h = dst + oh * sliding->in_h_step_; + + const float *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float *dst_w = dst_h + ow * sliding->block_channel_; + + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; +#ifdef ENABLE_ARM64 + DeconvDwFp32Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), + conv_param->kernel_w_ * C4NUM * sizeof(float)); +#else + DeconvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM); +#endif + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDwCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, + int in_kw_step) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDwPost(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float *dst_k = dst; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C4NUM; c++) { + dst_k[c] += bias[c]; + dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]); + dst_k[c] = (relu6) ? 
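+                     // (post pass: bias and activation are applied exactly once
+                     // per output pixel, after all scatter-accumulation is done)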
(MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]); + } + dst_k += block_channel; + } +} + +// deconv depthwise fp32: sliding window +void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const float *src = input_data; + float *dst = output_data; + if (conv_param->thread_num_ == 0) { + return; + } + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + DeconvDwBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding); + DeconvDwBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_, + conv_param, sliding); + DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, conv_param->input_w_, + conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), + sliding->in_kw_step_ * sizeof(float)); +#else + DeconvDwCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDwPost(dst_data, bias, sliding->block_channel_, conv_param); + } // output C4 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nhwc4 +} +/*deconv depthwise fp32 end*/ + +#ifdef ENABLE_AVX +void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, + const DepthwiseSWKernel kernel, int act_type, int ow_bock, int oc_block) { + // depthwise border computation + int ih = top * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float *src_h = src + ih * sw_param->in_h_step_; + float *dst_kernel = dst + left * sw_param->block_channel_; + for (int ow = left; ow < right; ow += ow_bock) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); +
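+    // start_kw/end_kw clip the kernel to columns whose input is in range; the
+    // clipped tail is forwarded below as kw_remainder so the asm kernels can
+    // skip the corresponding weight columns at the end of every kh row.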
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float *src_w = src_h + iw * sw_param->block_channel_; + const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_; + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM * oc_block; + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock, + oc_block, sw_param->block_channel_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_, + (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block); + dst_kernel += ow_bock * sw_param->block_channel_; + } // width loop +} + +void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sw_param, int task_id) { + int oh_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int oh_start = oh_step * task_id; + int oh_end = MSMIN(oh_start + oh_step, conv_param->output_h_); + if (oh_start >= oh_end) { + return; + } + // depthwise sliding window using x86 AVX instructions + int oc_tile_ = C8NUM; // oc is aligned to C8NUM in x86_64 AVX + int act_type = 0; + if (conv_param->act_type_ == ActType_Relu6) { + act_type += 1; + } + if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { + act_type += 2; + } + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int output_w = conv_param->output_w_; + int oc_algin = sw_param->block_channel_; + int oc_num = sw_param->c_block_; + int in_step = sw_param->in_step_; + int out_step = sw_param->out_step_; + int in_sw_step = sw_param->in_sw_step_; + int in_kw_step = sw_param->in_kw_step_; + int in_kh_step = sw_param->in_kh_step_; + int in_sh_step = sw_param->in_sh_step_; + int out_right = sw_param->right_; + int out_left = sw_param->left_; + int out_top = sw_param->top_; + int out_bottom = sw_param->bottom_; + int kernel_step = sw_param->kernel_step_; + int out_h_step = sw_param->out_h_step_; + int in_h_start = out_top * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = out_left * conv_param->stride_w_ - conv_param->pad_l_; + int in_start = in_h_start * sw_param->in_h_step_ + in_w_start * oc_algin; + const int ow_block_num[4] = {8, 4, 4, 3}; + const DepthwiseSWKernel kernel[4][2] = {{DepthwiseSW1x8Kernel, DepthwiseSW8x8Kernel}, + {DepthwiseSW1x16Kernel, DepthwiseSW4x16Kernel}, + {DepthwiseSW1x24Kernel, DepthwiseSW4x24Kernel}, + {DepthwiseSW1x32Kernel, DepthwiseSW3x32Kernel}}; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oh = oh_start; oh < oh_end; ++oh) { + float *dst_oh = output_data + oh * out_h_step; + const float *src_h = input_data + in_start + (oh - out_top) * in_sh_step; + int oc_block = 0; + const float *bias = bias_data; + for (int oc = 0; oc < oc_num; oc += oc_block) { + oc_block = MSMIN(C4NUM, oc_num - oc); // 4 3 2 1 + int oc_step = oc * oc_tile_; + const float *weight = weight_data + oc * kernel_step; + if (bias != NULL) { + bias = bias_data + oc_step; + } + float *dst_w = dst_oh + oc_step; + const DepthwiseSWKernel kernel_border = kernel[oc_block - 1][0]; + if (oh < out_top || oh >= out_bottom) { // oh in top or bottom border + DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, output_w, conv_param, sw_param, + kernel_border, act_type, 1, oc_block); + } else { // oh in center + // ow in left border + DepthwiseBorderAvxFp32(dst_w,
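+            // Columns [0, out_left): left border, one ow at a time. The center
+            // loop below uses the wide kernel variant ([1], ow_block_num =
+            // {8, 4, 4, 3} ow per call), dropping to the 1-wide [0] variant for
+            // the ow tail; the final call then covers the right border.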
input_data + oc_step, weight, bias, oh, 0, out_left, conv_param, sw_param, + kernel_border, act_type, 1, oc_block); + // ow in center + const float *src_w = src_h + oc_step; + int ow_block = ow_block_num[oc_block - 1]; // 8 4 4 3 + for (int ow = out_left; ow < out_right; ow += ow_block) { // left ~ right + if (ow_block > out_right - ow) { // remaining ow is less than a full block; process one ow at a time + ow_block = 1; + } + kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]]( + dst_w + ow * oc_algin, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, oc_algin, + in_kw_step, in_kh_step, in_sw_step, 0); + src_w += ow_block * in_sw_step; + } + // ow in right border + DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, out_right, output_w, conv_param, + sw_param, kernel_border, act_type, 1, oc_block); + } + } + } // output h loop + input_data += in_step; + output_data += out_step; + } // batch loop +} + +#ifdef ENABLE_DEBUG +void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + __m256 dst_data[12]; + __m256 src_data; + const float *src_kh[12]; + const float *src_kw[12]; + __m256 weight_data[4]; + for (int i = 0; i < ow_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * 8); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + src_kh[i] = src + i * in_sw_step; + src_kw[i] = NULL; + } + const float *weight_kernel = weight; + for (int kh = 0; kh < kernel_h; kh++) { + for (int i = 0; i < ow_block; ++i) { + src_kw[i] = src_kh[i]; + } + for (int kw = 0; kw < kernel_w; kw++) { + for (int j = 0; j < oc_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < ow_block; ++i) { // loop ow + for (int j = 0; j < oc_block; ++j) { + src_data = _mm256_loadu_ps(src_kw[i] + j * C8NUM); + dst_data[i * oc_block + j] += src_data * weight_data[j]; + } + } + for (int i = 0; i < ow_block; ++i) { + src_kw[i] += in_kw_step; // ic8 * dilation_w + } + weight_kernel += oc_block * C8NUM; + } // kernel_w loop + weight_kernel += kw_remainder; + for (int i = 0; i < ow_block; ++i) { + src_kh[i] += in_kh_step; // next kernel row + } + } // kernel_h loop + // add bias and relu + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + if (0x1 & act_flag) { // relu6 + dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); + } + _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]); + } + } +} +#endif + +void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups
(%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7), %%ymm14\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7), %%ymm14\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + + "vmovups 0x40(%1), %%ymm12\n" + "vmovups 0x40(%%rcx), %%ymm13\n" + "vmovups 0x40(%%rcx, %7), %%ymm14\n" + "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + + "vmovups 0x60(%1), %%ymm12\n" + "vmovups 0x60(%%rcx), %%ymm13\n" + "vmovups 0x60(%%rcx, %7), %%ymm14\n" + "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $128, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 8 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" 
// dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "vmovups %%ymm4, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm8, (%2, %1, 2)\n" + "vmovups %%ymm9, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + "vmovups 0x40(%%rcx), %%ymm6\n" + "vmovups 0x60(%%rcx), %%ymm7\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm7, %%ymm3\n" + "addq $128, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); +} + +void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq 
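+    // Prologue pattern shared by these asm kernels: cmpq/je tests bias == NULL;
+    // if non-NULL every accumulator starts from the bias vector (here 4 ow rows
+    // x 3 ymm registers each), otherwise vxorps zeroes them.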
$0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vmovups (%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x20(%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n" + + "vmovups 0x40(%1), %%ymm12\n" + "vmovups 0x40(%%rcx), %%ymm13\n" + "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x40(%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n" + + "addq $96, %1\n" + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + 
"vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + "vmovups %%ymm11, 0x40(%3)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + "vmovups 0x40(%%rcx), %%ymm6\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
+ "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n" + "addq $96, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); +} + +void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
+ "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7, 1), %%ymm14\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vmovups (%%rcx, %9), %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x20(%%rcx, %9), %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm2, %%ymm10\n" + + "addq $64, %1\n" + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm13", + "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + 
"vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "addq $64, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm4", "%ymm5"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); +} + +void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + float *dst_5 = dst + 5 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups (%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups (%0), %%ymm7\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + + asm volatile( + "LoopH:\n" + "movq %3, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "LoopW:\n" + "movq %%rcx, %%rax\n" + "vmovups (%1), %%ymm12\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vmovups (%%rax, %6, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %7, %%rax\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vmovups (%%rax, %6, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %7, %%rax\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + + "addq $32, %1\n" + "addq %4, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg LoopW\n" 
+ + "addq %5, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %2\n" + "jg LoopH\n" + : + : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), "r"(in_kh_step), // 5 + "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 8 + : "%rcx", "%rsi", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", + "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je Write\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + + "and $0x1, %%eax\n" + "je Write\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + + "Write:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
+ "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "addq $32, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm4"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm12", "%ymm14"); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.h new file mode 100644 index 00000000..988391ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_depthwise_fp32.h @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ +#define MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/base/conv_common_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ENABLE_ARM64 +void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6); +#endif + +int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id); + +int ConvDwAVX512(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param_); + +void ConvDwAVX512Fp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step, bool first_calc_flag, const float *bias); + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block); + +void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int in_block, + int weight_block); + +void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +bool CheckConvDwUse3X3(const ConvParameter *conv_param); + +bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param); + +void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param, + int step_h, int step_w); + +#ifdef ENABLE_ARM64 +void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, size_t input_stride, size_t relu, size_t relu6); + +void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, size_t input_stride, size_t relu, size_t relu6); +#endif + +#ifdef ENABLE_AVX +typedef void (*DepthwiseSWKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t 
kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, + const DepthwiseSWKernel kernel, int act_type, int ow_bock, int oc_block); + +void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6); + +void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6); +#ifdef ENABLE_DEBUG +void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); +#endif +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, + bool relu, bool relu6); +void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh); +#endif + +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel); + +void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, + float *zero_ptr, const ConvParameter *conv_param, int task_id); + 
+void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.c new file mode 100644 index 00000000..308bae07 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.c @@ -0,0 +1,92 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_im2col_avx512_fp32.h" +#include "nnacl_c/fp32/conv_im2col_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_avx512_instructions.h" + +// fp32 conv common +void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param, + int cal_num) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM); + + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } + int out_stride = out_channel_align * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel_align * output_hw + start_hw * out_channel_align; + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + + float *gemm_output = output_data + out_offset; + MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + out_channel_align, out_channel_align, real_cal_row); + } + } +} + +// fp32 conv common +void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, + const ConvParameter *conv_param, int cal_num) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM); + 
+ int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + int out_stride = out_channel_align * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = start_batch; b < end_batch; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel_align * output_hw; + for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + + float *gemm_output = output_data + out_offset; + MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + out_channel_align, out_channel_align, real_cal_row); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.h new file mode 100644 index 00000000..9991c8d7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_avx512_fp32.h @@ -0,0 +1,38 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ +#define MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ + +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param, + int cal_num); + +void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, + const ConvParameter *conv_param, int cal_num); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.c new file mode 100644 index 00000000..9c5f8c92 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.c @@ -0,0 +1,65 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_im2col_fp32.h" + +void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index) { + // input format : nhwc + int kernel_w = conv_param->kernel_w_; + int kernel_h = conv_param->kernel_h_; + int kernel_plane = kernel_h * kernel_w; + int dilation_w = conv_param->dilation_w_; + int dilation_h = conv_param->dilation_h_; + + int out_w = conv_param->output_w_; + if (dilation_w == 0 || dilation_h == 0 || out_w == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int in_w = conv_param->input_w_; + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) { + continue; + } + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); + int input_stride = (input_h * in_w + input_w) * in_channel; + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); + } + } // kernel_h loop + } + } // tile num loop +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.h new file mode 100644 index 00000000..19ff0418 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_im2col_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ +#define MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ + +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw.h new file mode 100644 index 00000000..e95a6911 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw.h @@ -0,0 +1,131 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_SW_H_ +#define MINDSPORE_NNACL_FP32_CONV_SW_H_ + +#define GenerateConvSWFunc(backend, oc_unit_num, row_num_list, kernel_list, compute_core, outer_compute) \ + void SWBorder##backend(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, \ + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, \ + const SWConvKernel kernel, int act_type, int ow_bock, int oc_block, size_t write_mode) { \ + for (int oh = top; oh < bottom; oh++) { \ + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; \ + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); \ + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); \ + const float *src_h = src + ih * sw_param->in_h_step_; \ + float *dst_kernel = dst + left * sw_param->out_w_step_; \ + for (int ow = left; ow < right; ow += ow_bock) { \ + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; \ + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); \ + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); \ + const float *src_w = src_h + iw * sw_param->ic_align_; \ + const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_; \ + const float *weight_kernel = \ + weight + (start_kh * conv_param->kernel_w_ + start_kw) * sw_param->ic_align_ * C8NUM * oc_block; \ + outer_compute dst_kernel += ow_bock * sw_param->out_w_step_; \ + } \ + dst += sw_param->out_h_step_; \ + } \ + } \ + \ + void ConvSW##backend##Fp32(const float *input_data, const float *packed_weight, const float *bias_data, \ + float *output_data, int task_id, ConvParameter *conv_param, \ + SlidingWindowParam *sw_param) { \ + int out_h = conv_param->output_h_; \ + int oh_step = UP_DIV(out_h, conv_param->thread_num_); \ + int oh_start = oh_step * task_id; \ + int oh_end = MSMIN(oh_start + oh_step, out_h); \ + if (oh_start >= oh_end) { \ + return; \ + } \ + int oc_tile_ = C8NUM; /* oc in algin to C8NUM in arm64 */ \ + int act_type = 0; \ + if (conv_param->act_type_ == ActType_Relu6) { \ + act_type += 1; \ + } \ + if (conv_param->act_type_ == 
ActType_Relu || conv_param->act_type_ == ActType_Relu6) { \ + act_type += 2; \ + } \ + int kernel_h = conv_param->kernel_h_; \ + int kernel_w = conv_param->kernel_w_; \ + int ic_algin = sw_param->ic_align_; \ + int in_sw_step = sw_param->in_sw_step_; \ + int in_kw_step = sw_param->in_kw_step_; \ + int in_kh_step = sw_param->in_kh_step_; \ + int in_sh_step = sw_param->in_sh_step_; \ + int out_h_step = sw_param->out_h_step_; \ + int out_c_step = sw_param->out_c_step_; \ + int out_w_step = sw_param->out_w_step_; \ + int out_block_step = sw_param->out_block_step_; \ + int kernel_step = sw_param->kernel_step_; \ + int in_step = sw_param->in_step_; \ + int out_step = sw_param->out_step_; \ + int c_block = sw_param->c_block_; \ + int top = sw_param->top_; \ + int left = sw_param->left_; \ + int right = sw_param->right_; \ + int bottom = sw_param->bottom_; \ + int stride_h = conv_param->stride_h_; \ + int stride_w = conv_param->stride_w_; \ + int out_w = conv_param->output_w_; \ + int pad_u = conv_param->pad_u_; \ + int pad_l = conv_param->pad_l_; \ + int in_h_step = sw_param->in_h_step_; \ + int out_batch = conv_param->output_batch_; \ + int in_h_start = top * stride_h - pad_u; \ + int in_w_start = left * stride_w - pad_l; \ + int center_step = in_h_start * in_h_step + in_w_start * ic_algin; \ + int write_mode = conv_param->out_format_; \ + row_num_list kernel_list for (int b = 0; b < out_batch; b++) { \ + for (int oh = oh_start; oh < oh_end; oh += 1) { \ + float *dst_oh = output_data + oh * out_h_step; \ + const float *src_h = input_data + center_step; \ + \ + int oc_block = 0; \ + const float *bias = bias_data; \ + for (int oc = 0; oc < c_block; oc += oc_block) { \ + oc_block = MSMIN(oc_unit_num, c_block - oc); \ + const float *weight = packed_weight + oc * kernel_step; \ + if (bias != NULL) { \ + bias = bias_data + oc * oc_tile_; \ + } \ + float *dst_oc = dst_oh + oc * out_c_step; \ + const SWConvKernel kernel_border = kernel[oc_block - 1][0]; \ + if (oh < top || oh >= bottom) { /* oh in up or down border */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, out_w, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + } else { /* oh in center */ \ + /* ow in right */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, left, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + /* ow in center */ \ + const float *src_w = src_h + (oh - top) * in_sh_step; \ + int ow_block = ow_block_num[oc_block - 1]; \ + for (int ow = left; ow < right; ow += ow_block) { /* left ~ right */ \ + ow_block = MSMIN(ow_block, right - ow); \ + compute_core src_w += ow_block * in_sw_step; \ + } \ + /* ow in left */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, right, out_w, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + } \ + } \ + } /* output h loop */ \ + input_data += in_step; \ + output_data += out_step; \ + } /* batch loop */ \ + } +#endif // MINDSPORE_NNACL_FP32_CONV_SW_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c new file mode 100644 index 00000000..5b15d4fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.c @@ -0,0 +1,99 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h" +#include "nnacl_c/fp32/conv_sw.h" + +bool CheckArm64UseSWConv(const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + return false; + } + if (conv_param->input_channel_ > C128NUM) { + return false; + } + if (conv_param->kernel_h_ > C5NUM || conv_param->kernel_w_ > C5NUM) { + return false; + } + if (conv_param->dilation_h_ != 1 || conv_param->dilation_w_ != 1) { + return false; + } + if (conv_param->stride_w_ > C3NUM) { + return false; + } + if (conv_param->input_h_ / conv_param->kernel_h_ < C48NUM || conv_param->input_w_ / conv_param->kernel_w_ < C48NUM) { + return false; + } + return true; +} + +typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv2x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv2x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv3x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv3x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, 
size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv5x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv5x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +#define ROW_NUM_LIST const int ow_block_num[2] = {5, 5}; +#define KERNEL_LIST \ + const SWConvKernel kernel[2][5] = { \ + {SWConv1x8Kernel, SWConv2x8Kernel, SWConv3x8Kernel, SWConv4x8Kernel, SWConv5x8Kernel}, \ + {SWConv1x16Kernel, SWConv2x16Kernel, SWConv3x16Kernel, SWConv4x16Kernel, SWConv5x16Kernel}}; +#define COMPUTE_CORE \ + kernel[oc_block - 1][ow_block - 1](dst_oc + ow * out_w_step, src_w, weight, bias, kernel_h, kernel_w, act_type, \ + out_block_step, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode); +#define OUTER_COMPUTE \ + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, \ + sw_param->out_block_step_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_, \ + sw_param->in_sw_step_, (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, \ + write_mode); +GenerateConvSWFunc(Arm64, C2NUM, ROW_NUM_LIST, KERNEL_LIST, COMPUTE_CORE, OUTER_COMPUTE); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h new file mode 100644 index 00000000..cb38240d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_arm64_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP32_CONV_SW_ARM64_FP32_H_ +#define NNACL_FP32_CONV_SW_ARM64_FP32_H_ +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +bool CheckArm64UseSWConv(const ConvParameter *conv_param); +void ConvSWArm64Fp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_CONV_SW_ARM64_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c new file mode 100644 index 00000000..1c979e63 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.c @@ -0,0 +1,1231 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/conv_sw_avx_fp32.h" +#include "nnacl_c/fp32/conv_sw.h" +#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h" + +void SWConv3x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; + out_step *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %8), %%ymm14\n" + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vmovups (%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 0x20(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, 
%%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 0x40(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 0x60(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $128, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "vmovups %%ymm4, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm8, (%2, %1, 2)\n" + "vmovups %%ymm9, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm4, 0x20(%2)\n" + "vmovups %%ymm8, 0x40(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm9, 0x40(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm6, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "vmovups %%ymm7, 0x20(%4)\n" + "vmovups %%ymm11, 0x40(%4)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, 
size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm4\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + "addq $128, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); +} + +void SWConv4x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups 0x20(%0), %%ymm1\n" + "vmovups 0x40(%0), %%ymm2\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
+ "vmovups (%0), %%ymm3\n" + "vmovups 0x20(%0), %%ymm4\n" + "vmovups 0x40(%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups 0x20(%0), %%ymm7\n" + "vmovups 0x40(%0), %%ymm8\n" + "vmovups (%0), %%ymm9\n" + "vmovups 0x20(%0), %%ymm10\n" + "vmovups 0x40(%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", + "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + "vmovups 0x40(%1), %%ymm14\n" + + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm2\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm5\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "addq %2, %%rdx\n" // src_3 + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + + "subq %2, %%rdx\n" + "addq $96, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(src_3_step), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, 
%%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + "vmovups %%ymm11, 0x40(%3)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm3, 0x20(%2)\n" + "vmovups %%ymm6, 0x40(%2)\n" + "vmovups %%ymm9, 0x60(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm7, 0x40(%2, %1, 1)\n" + "vmovups %%ymm10, 0x60(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm5, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm3\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
+ "vfmadd231ps (%1), %%ymm3, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm3, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm3, %%ymm2\n" + "addq $96, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc4hw4 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); +} + +void SWConv6x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups 0x20(%0), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. 
+ "vmovups (%0), %%ymm2\n" + "vmovups 0x20(%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups 0x20(%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups 0x20(%0), %%ymm7\n" + "vmovups (%0), %%ymm8\n" + "vmovups 0x20(%0), %%ymm9\n" + "vmovups (%0), %%ymm10\n" + "vmovups 0x20(%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", + "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm3\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm5\n" + + "addq %2, %%rdx\n" // src_3 + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm9\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm11\n" + + "subq %2, %%rdx\n" + "addq $64, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(src_3_step), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps 
%%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, (%2, %1, 1)\n" + "vmovups %%ymm3, 0x20(%2, %1, 1)\n" + "vmovups %%ymm4, (%2, %1, 2)\n" + "vmovups %%ymm5, 0x20(%2, %1, 2)\n" + "vmovups %%ymm6, (%3)\n" // dst+3 + "vmovups %%ymm7, 0x20(%3)\n" + "vmovups %%ymm8, (%3, %1, 1)\n" + "vmovups %%ymm9, 0x20(%3, %1, 1)\n" + "vmovups %%ymm10, (%3, %1, 2)\n" + "vmovups %%ymm11, 0x20(%3, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm2, 0x20(%2)\n" + "vmovups %%ymm4, 0x40(%2)\n" + "vmovups %%ymm6, 0x60(%2)\n" // dst+3 + "vmovups %%ymm8, 0x80(%2)\n" + "vmovups %%ymm10, 0xA0(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm3, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm9, 0x80(%2, %1, 1)\n" + "vmovups %%ymm11, 0xA0(%2, %1, 1)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm3\n" + // Weight data is loaded directly from memory instead of into registers for calculation. 
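+ // Illustrative intrinsics sketch of this step (hypothetical names): unlike the
+ // multi-pixel kernels above, which preload the weight vectors into ymm registers
+ // and reuse them across several broadcast inputs, the single-pixel kernels fold
+ // the weight load into the FMA itself:
+ //   __m256 in = _mm256_set1_ps(*src_ic);                   // vbroadcastss
+ //   acc0 = _mm256_fmadd_ps(in, _mm256_loadu_ps(w), acc0);  // vfmadd231ps (mem)
+ //   acc1 = _mm256_fmadd_ps(in, _mm256_loadu_ps(w + 8), acc1);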
+ "vfmadd231ps (%1), %%ymm3, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm3, %%ymm1\n" + "addq $64, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm3"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "jmp 2f\n" + "1:\n" + // write nc8hw8 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); +} + +void SWConv12x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + float *dst_5 = dst + 5 * out_step; + float *dst_9 = dst + 9 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups (%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups (%0), %%ymm7\n" + "vmovups (%0), %%ymm8\n" + "vmovups (%0), %%ymm9\n" + "vmovups (%0), %%ymm10\n" + "vmovups (%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + + asm volatile( + "LoopH:\n" + "movq %3, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "LoopW:\n" + "movq %%rcx, %%rdx\n" + "movq %4, %%r12\n" // ic_algin + "LoopIC:\n" + "vmovups (%1), %%ymm12\n" + "addq $32, %1\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + 
"vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "subq %8, %%rdx\n" + "subq %8, %%rdx\n" + "subq %8, %%rdx\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg LoopIC\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg LoopW\n" + + "addq %6, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %2\n" + "jg LoopH\n" + : + : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), "r"(in_kh_step), // 6 + "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%r12", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je Write\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je Write\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "Write:\n" + "cmpq $13, %6\n" + "je WriteNC8HW8\n" + // write nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + "vmovups %%ymm8, (%2, %1, 8)\n" + "vmovups %%ymm9, (%5)\n" // dst_9 + "vmovups %%ymm10, (%5, %1, 1)\n" + "vmovups %%ymm11, (%5, %1, 2)\n" + "jmp End\n" + "WriteNC8HW8:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" // dst_3 + "vmovups %%ymm4, 0x80(%2)\n" + "vmovups %%ymm5, 0xA0(%2)\n" // dst_5 + "vmovups %%ymm6, 0xC0(%2)\n" + "vmovups %%ymm7, 0xE0(%2)\n" + "vmovups %%ymm8, 0x100(%2)\n" + "vmovups %%ymm9, 0x120(%2)\n" // dst_9 + "vmovups %%ymm10, 0x140(%2)\n" + "vmovups %%ymm11, 0x160(%2)\n" + "End:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", 
"%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv4x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + size_t src_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "movq %%rdx, %%rax\n" + "addq $32, %1\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %8), %%ymm14\n" + "vbroadcastss (%%rax, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %9, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %2, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(kw_remainder), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(src_step) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm12", "%ymm14"); +} + +void SWConv1x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" // LoopH + "movq %4, 
%%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm1\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm1, %%ymm0\n" + "addq $32, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "0:\n" + // write to nhec and nc8hw8 is identical! + "vmovups %%ymm0, (%2)\n" // dst_0 + : + : "a"(act_flag), "r"(out_step), "r"(dst) + : "%ecx", "%ymm0", "%ymm12", "%ymm14"); +} + +typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, + size_t kw_remainder, size_t write_mode); + +#define ROW_NUM_LIST const int ow_block_num[4] = {12, 6, 4, 3}; +#define KERNEL_LIST \ + const SWConvKernel kernel[4][2] = {{SWConv1x8AVXKernel, SWConv12x8AVXKernel}, \ + {SWConv1x16AVXKernel, SWConv6x16AVXKernel}, \ + {SWConv1x24AVXKernel, SWConv4x24AVXKernel}, \ + {SWConv1x32AVXKernel, SWConv3x32AVXKernel}}; +#define COMPUTE_CORE \ + if (ow_block < ow_block_num[oc_block - 1]) { \ + ow_block = 1; \ + } \ + kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]]( \ + dst_oc + ow * out_w_step, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, out_block_step, \ + ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode); +#define OUTER_COMPUTE \ + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock, \ + oc_block, sw_param->out_block_step_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_, \ + sw_param->in_sw_step_, (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, \ + write_mode); + +GenerateConvSWFunc(AVX, C4NUM, ROW_NUM_LIST, KERNEL_LIST, COMPUTE_CORE, OUTER_COMPUTE); + +#ifdef ENABLE_DEBUG +void SWConvWxKAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + __m256 dst_data[12]; + const float *src_kh[12]; + const float *src_kw[12]; + __m256 weight_data[4]; + for (int i = 0; i < ow_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + src_kh[i] = src + i * in_sw_step; + src_kw[i] = NULL; + } + const float *weight_kernel = weight; + for (int kh = 
0; kh < kernel_h; kh++) { + for (int i = 0; i < ow_block; ++i) { + src_kw[i] = src_kh[i]; + } + for (int kw = 0; kw < kernel_w; kw++) { + for (int ic = 0; ic < ic_algin; ++ic) { + for (int j = 0; j < oc_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] += src_kw[i][ic] * weight_data[j]; + } + } + weight_kernel += C8NUM * oc_block; + } // ic loop + for (int i = 0; i < ow_block; ++i) { + src_kw[i] += in_kw_step; + } + } // kernel_w loop + weight_kernel += kw_remainder; + for (int i = 0; i < ow_block; ++i) { + src_kh[i] += in_kh_step; + } + } // kernel_h loop + // add bias and relu + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + if (0x1 & act_flag) { // relu6 + dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); + } + if (write_mode == C13NUM) { + // write nc8hw8 + _mm256_storeu_ps(dst + j * out_step + i * C8NUM, dst_data[i * oc_block + j]); + } else { + // write nhwc + _mm256_storeu_ps(dst + i * out_step + j * C8NUM, dst_data[i * oc_block + j]); + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.h new file mode 100644 index 00000000..ec4bd2b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_sw_avx_fp32.h @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ +#define MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ + +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvSWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); + +#ifdef ENABLE_DEBUG +void SWConvWxKAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c new file mode 100644 index 00000000..d4e16721 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.c @@ -0,0 +1,265 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/conv_winograd_fp32.h" +#include +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/fp32/matmul_fp32.h" + +// fp32 conv winograd +void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param, + TransFuncList trans_func) { + if (conv_param->output_unit_ == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + const int tile_num = C12NUM; + int output_tile_count = UP_DIV(output_count, tile_num); +#ifdef ENABLE_AVX + const int col_tile = C16NUM; + const int channel_pack_tile = C8NUM; +#else + const int col_tile = C8NUM; + const int channel_pack_tile = C4NUM; +#endif + int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile; + float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + + int block_per_thread = UP_DIV(output_tile_count, conv_param->thread_num_); + int start_index = block_per_thread * task_id * tile_num; + if (start_index >= output_count) { + return; + } + int end_index = MSMIN(start_index + block_per_thread * tile_num, output_count); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; + + for (int out_tile_index = start_index; out_tile_index < end_index; out_tile_index += tile_num) { + int cal_num = output_count - out_tile_index; + cal_num = cal_num > tile_num ? tile_num : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4. + // For arm32, the tile_num is 4. + // For x86_sse, the tile_num is 4, the channel_tile is 4. + // For avx, the tile_num is 6, the channel_tile is 8. + // N = input_unit, M = tile_num + // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile); + WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile; + for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) { + int real_c = in_channel - c * channel_pack_tile; + real_c = real_c > channel_pack_tile ? 
channel_pack_tile : real_c; + float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile; + float *dst_c = trans_input + c * tile_num * channel_pack_tile; + trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile; + MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float *src_ptr = trans_input; + float *dst_ptr = gemm_out; + float *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#else + RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#endif + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, + in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + float *output_ptr = output_data + out_batch_offset; + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); + } else { +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) + WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#else + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#endif + } + } + } +} + +// fp32 conv winograd +void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data, + float *output_data, TmpBufferAddress *buffer_list, int task_id, + const ConvParameter *conv_param, TransFuncList trans_func) { + if (conv_param->output_unit_ == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + const int tile_num = C12NUM; +#ifdef ENABLE_AVX + const int col_tile = C16NUM; + const int channel_pack_tile = C8NUM; +#else + const int col_tile = C8NUM; + const int channel_pack_tile = C4NUM; +#endif + int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile; + float *col_buffer = buffer_list[3] + task_id 
* tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + for (int b = start_batch; b < end_batch; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; + + for (int out_tile_index = 0; out_tile_index < output_count; out_tile_index += tile_num) { + int cal_num = output_count - out_tile_index; + cal_num = cal_num > tile_num ? tile_num : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4. + // For arm32, the tile_num is 4. + // For x86_sse, the tile_num is 4, the channel_tile is 4. + // For avx, the tile_num is 6, the channel_tile is 8. + // N = input_unit, M = tile_num + // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile); + WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile; + for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) { + int real_c = in_channel - c * channel_pack_tile; + real_c = real_c > channel_pack_tile ? 
channel_pack_tile : real_c; + float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile; + float *dst_c = trans_input + c * tile_num * channel_pack_tile; + trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile; + MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float *src_ptr = trans_input; + float *dst_ptr = gemm_out; + float *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#else + RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#endif + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, + in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + float *output_ptr = output_data + out_batch_offset; + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); + } else { +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) + WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#else + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#endif + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.h new file mode 100644 index 00000000..6d3b4982 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/conv_winograd_fp32.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_ +#define MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/pack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/kernel/convolution_winograd_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// fp32 convolution winograd +void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param, + TransFuncList trans_func); + +void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data, + float *output_data, TmpBufferAddress *buffer_list, int task_id, + const ConvParameter *conv_param, TransFuncList trans_func); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.c new file mode 100644 index 00000000..c171deca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.c @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/crop_fp32.h" +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +void Pad4DOffset(const CropParameter *crop_param, int64_t *offset, int length) { + int axis = crop_param->axis_; + for (int i = length - 1; i >= 0; --i) { + int offset_index = i - axis; + if (offset_index >= 0 && offset_index < COMM_SHAPE_SIZE) { + offset[i] = crop_param->offset_[offset_index]; + } else { + offset[i] = 0; + } + } +} + +void Crop4D(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param, int thread_id, int thread_num) { + int64_t offset_pad[DIMENSION_4D] = {0}; + Pad4DOffset(crop_param, offset_pad, DIMENSION_4D); + int out_shape1 = out_shape[1]; + int out_shape2 = out_shape[2]; + int out_shape3 = out_shape[3]; + size_t out_stride2 = out_shape3; + size_t out_stride1 = out_stride2 * out_shape2; + size_t out_stride0 = out_stride1 * out_shape1; + size_t in_stride2 = in_shape[3]; + size_t in_stride1 = in_stride2 * in_shape[2]; + size_t in_stride0 = in_stride1 * in_shape[1]; + size_t copy_size = out_shape3 * sizeof(float); + + size_t count_per_thread = UP_DIV(out_shape1, thread_num); + size_t thread_stride = thread_id * count_per_thread; + for (int i = 0; i < out_shape[0]; ++i) { + size_t out_offset0 = i * out_stride0; + size_t in_offset0 = (i + offset_pad[0]) * in_stride0 + offset_pad[3]; + for (size_t j = 0; j < count_per_thread; ++j) { + size_t k = j + thread_stride; + if (k >= out_shape1) { + break; + } + size_t out_offset1 = k * out_stride1 + out_offset0; + size_t in_offset1 = (k + offset_pad[1]) * in_stride1 + in_offset0; + for (int l = 0; l < out_shape2; ++l) { + size_t out_offset = l * out_stride2 + out_offset1; + size_t in_offset = (l + offset_pad[2]) * in_stride2 + in_offset1; + memcpy(output + out_offset, input + in_offset, copy_size); + } + } + } +} + +void Crop4DNoParallel(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param) { + int64_t offset_pad[DIMENSION_4D] = {0}; + Pad4DOffset(crop_param, offset_pad, DIMENSION_4D); + size_t in_dim2_stride = in_shape[3]; + size_t in_dim1_stride = in_shape[2] * in_dim2_stride; + size_t in_dim0_stride = in_dim1_stride * in_shape[1]; + size_t offset_3 = offset_pad[3]; + size_t out_offset = 0; + size_t copy_num = out_shape[3]; + size_t copy_size = copy_num * sizeof(float); + size_t in_dim0_end = offset_pad[0] + out_shape[0]; + size_t in_dim1_end = offset_pad[1] + out_shape[1]; + size_t in_dim2_end = offset_pad[2] + out_shape[2]; + for (int i = offset_pad[0]; i < in_dim0_end; ++i) { + size_t dim0_offset = (size_t)i * in_dim0_stride + offset_3; + for (int j = offset_pad[1]; j < in_dim1_end; ++j) { + size_t dim1_offset = (size_t)j * in_dim1_stride + dim0_offset; + for (int k = offset_pad[2]; k < in_dim2_end; ++k) { + size_t in_offset = dim1_offset + (size_t)k * in_dim2_stride; + memcpy(output + out_offset, input + in_offset, copy_size); + out_offset += copy_num; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.h new file mode 100644 index 00000000..07b66e62 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/crop_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_CROP_FP32_H_ +#define NNACL_FP32_CROP_FP32_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Crop4D(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param, int thread_id, int thread_num); +void Crop4DNoParallel(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_CROP_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.c new file mode 100644 index 00000000..1900c27b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.c @@ -0,0 +1,200 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/cumsum_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/cumsum_fp32_simd.h" + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +void Cumsum(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + // when not exclusive, output axis dim[0] is the same as that of input. + // when exclusive, output axis dim[0] is 0.0f + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim; + float *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + float *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0.0f; + } + } + } + int input_offset = exclusive ? 
0 : 1; + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim + inner_dim * input_offset; + float *layer_last_output = output + i * axis_dim * inner_dim; + float *layer_output = layer_last_output + inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(Cumsum, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + // layer_output (i, j, k) = layer_input (i, j, k) + layer_last_output (i,j-1, k) + *(layer_output + k) = *(layer_input + k) + *(layer_last_output + k); + } + layer_input += inner_dim; + layer_last_output += inner_dim; + layer_output += inner_dim; + } + } +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +void CumsumReverse(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + float *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + float *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0.0f; + } + } + } + int input_offset = exclusive ? 0 : 1; + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + (i + 1) * axis_dim * inner_dim - 1 - input_offset * inner_dim; + float *layer_last_output = output + (i + 1) * axis_dim * inner_dim - 1; + float *layer_output = layer_last_output - inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(CumsumReverse, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + *(layer_output - k) = *(layer_input - k) + *(layer_last_output - k); + } + layer_input -= inner_dim; + layer_last_output -= inner_dim; + layer_output -= inner_dim; + } + } +} + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +void CumsumInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + // when not exclusive, output axis dim[0] is the same as that of input. + // when exclusive, output axis dim[0] is 0 + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + i * axis_dim * inner_dim; + int32_t *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + int32_t *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0; + } + } + } + int input_offset = exclusive ? 
0 : 1; + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + i * axis_dim * inner_dim + inner_dim * input_offset; + int32_t *layer_last_output = output + i * axis_dim * inner_dim; + int32_t *layer_output = layer_last_output + inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(CumsumInt, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + *(layer_output + k) = *(layer_input + k) + *(layer_last_output + k); + } + layer_input += inner_dim; + layer_last_output += inner_dim; + layer_output += inner_dim; + } + } +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +void CumsumReverseInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + int32_t *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + int32_t *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0; + } + } + } + int input_offset = exclusive ? 0 : 1; + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + (i + 1) * axis_dim * inner_dim - 1 - input_offset * inner_dim; + int32_t *layer_last_output = output + (i + 1) * axis_dim * inner_dim - 1; + int32_t *layer_output = layer_last_output - inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(CumsumReverseInt, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + *(layer_output - k) = *(layer_input - k) + *(layer_last_output - k); + } + layer_input -= inner_dim; + layer_last_output -= inner_dim; + layer_output -= inner_dim; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.h new file mode 100644 index 00000000..6b30cfc9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_CUMSUM_H_ +#define MINDSPORE_NNACL_FP32_CUMSUM_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/cumsum_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Cumsum(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +void CumsumReverse(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +void CumsumInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +void CumsumReverseInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CUMSUM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in new file mode 100644 index 00000000..ad5aa287 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/cumsum_fp32_simd.h.in @@ -0,0 +1,114 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CUMSUM_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_CUMSUM_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +static inline int64_t CumsumOutputInitWithInput@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, + float *layer_output, int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(layer_output + index, SIMD_LD_F32(layer_input + index)); + } + return index; +} + +static inline int64_t CumsumOutputInitWithZero@SIMD_INSTRUCTION@(int64_t index, float *layer_output, int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(layer_output + index, SIMD_MOV_F32(0.0f)); + } + return index; +} + +static inline int64_t Cumsum@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, float *layer_output, float *layer_last_output, + int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input_val = SIMD_LD_F32(layer_input + index); + SIMD_F32 last_output_val = SIMD_LD_F32(layer_last_output + index); + SIMD_F32 out_val = SIMD_ADD_F32(input_val, last_output_val); + SIMD_ST_F32(layer_output + index, out_val); + } + return index; +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +static inline int64_t CumsumReverse@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, float *layer_output, + float *layer_last_output, int inner_dim) { + + for (int block_max_size = inner_dim - 
BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 input_val = SIMD_LD_F32(layer_input - index - BLOCK_NUM + 1);
+    SIMD_F32 last_output_val = SIMD_LD_F32(layer_last_output - index - BLOCK_NUM + 1);
+    SIMD_F32 out_val = SIMD_ADD_F32(input_val, last_output_val);
+    SIMD_ST_F32(layer_output - index - BLOCK_NUM + 1, out_val);
+  }
+  return index;
+}
+
+// (a, b, c) -> (a, a+b, a+b+c) exclusive == false
+// (a, b, c) -> (0, a, a+b) exclusive == true
+static inline int64_t CumsumIntOutputInitWithInput@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input,
+                                                                     int32_t *layer_output, int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(layer_output + index, SIMD_LD_EPI32(layer_input + index));
+  }
+  return index;
+}
+
+static inline int64_t CumsumIntOutputInitWithZero@SIMD_INSTRUCTION@(int64_t index, int32_t *layer_output, int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(layer_output + index, SIMD_MOV_EPI32(0));
+  }
+  return index;
+}
+
+static inline int64_t CumsumInt@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input, int32_t *layer_output, int32_t *layer_last_output,
+                                                  int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 input_val = SIMD_LD_EPI32(layer_input + index);
+    SIMD_EPI32 last_output_val = SIMD_LD_EPI32(layer_last_output + index);
+    SIMD_EPI32 out_val = SIMD_ADD_EPI32(input_val, last_output_val);
+    SIMD_ST_EPI32(layer_output + index, out_val);
+  }
+  return index;
+}
+
+// (a, b, c) -> (c+b+a, c+b, c) exclusive==false
+// (a, b, c) -> (c+b, c, 0) exclusive==true
+static inline int64_t CumsumReverseInt@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input, int32_t *layer_output, int32_t *layer_last_output,
+                                                         int inner_dim) {
+  for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 input_val = SIMD_LD_EPI32(layer_input - index - BLOCK_NUM + 1);
+    SIMD_EPI32 last_output_val = SIMD_LD_EPI32(layer_last_output - index - BLOCK_NUM + 1);
+    SIMD_EPI32 out_val = SIMD_ADD_EPI32(input_val, last_output_val);
+    SIMD_ST_EPI32(layer_output - index - BLOCK_NUM + 1, out_val);
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c
new file mode 100644
index 00000000..290ba8fd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.c
@@ -0,0 +1,72 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "nnacl_c/fp32/custom_gru_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden, + const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4], + const CustomGruParameter *gru_param) { + int num_step = gru_param->num_step; + int batch_size = gru_param->batch_size; + int input_size = gru_param->input_size; + int hidden_size = gru_param->hidden_size; + int output_size = batch_size * hidden_size; + int double_output_size = output_size * C2NUM; + int col_align = UP_ROUND(hidden_size, C8NUM); + int weight_in_offset = col_align * input_size; + int weight_hidden_offset = col_align * hidden_size; + float *input_gate = buffer[1]; + float *hidden_gate = buffer[C3NUM]; + for (int i = 0; i < num_step; ++i) { + if (batch_size != 1) { + RowMajor2Col12MajorParallel(input + i * batch_size * input_size, buffer[0], batch_size, input_size, 0, + batch_size); + for (int j = 0; j < C3NUM; ++j) { + MatMulOpt(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + RowMajor2Col12MajorParallel(init_h, buffer[C2NUM], batch_size, hidden_size, 0, batch_size); + for (int j = 0; j < C3NUM; ++j) { + MatMulOpt(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + } else { + for (int j = 0; j < C3NUM; ++j) { + MatVecMulPackFp32(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, hidden_size); + MatVecMulPackFp32(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size); + } + } + ElementAdd(input_gate, hidden_gate, input_gate, double_output_size); + Sigmoid(input_gate, double_output_size, input_gate); + ElementMul(input_gate, hidden_gate + double_output_size, input_gate, output_size); + ElementAdd(input_gate, input_gate + double_output_size, input_gate, output_size); + Tanh(input_gate, output_size, input_gate); + ElementSub(init_h, input_gate, hidden_gate, output_size); + ElementMul(input_gate + output_size, hidden_gate, hidden_gate, output_size); + ElementAdd(input_gate, hidden_gate, output, output_size); + init_h = output; + output += output_size; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h new file mode 100644 index 00000000..16c47749 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/custom_gru_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#define MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#ifdef ENABLE_ARM64 +#include "nnacl_c/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden, + const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4], + const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.c new file mode 100644 index 00000000..72bec29a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.c @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/deconv_fp32.h" + +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane) { + /* ichwoc(nhwc) -> oc4 * h * w * incUP4 * 4 */ + int ic_up4 = UP_ROUND(input_channel, C4NUM); + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM; + int oc4mod = oc % C4NUM; + for (int ic = 0; ic < input_channel; ic++) { + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * plane * output_channel + hw * output_channel + oc; + int dst_index = oc4div * ic_up4 * plane * C4NUM + hw * ic_up4 * C4NUM + ic * C4NUM + oc4mod; + dst[dst_index] = weight[src_index]; + } + } + } + return; +} + +void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *dst, int output_channel, + const ConvParameter *conv_param) { + /* arm64 row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + /* arm32 row4x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_ROUND(output_channel, C8NUM); +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + const int tile_num = 4; +#else + const int tile_num = 12; +#endif + int in_plane_round = UP_ROUND(input_plane, tile_num); + int src_iw_stride = C8NUM; + int src_ih_stride = conv_param->input_w_ * C8NUM; + int src_kw_stride = in_plane_round * C8NUM; + int src_kh_stride = in_plane_round * conv_param->kernel_w_ * C8NUM; + int dst_oh_stride = conv_param->output_w_ * C8NUM; + int dst_ow_stride = C8NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM; + int dst_kw_stride = conv_param->dilation_w_ * C8NUM; + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + for (int c = 0; c < oc8; c += 8) { + float *dst_ptr = tmp + c * output_plane; + const float *src_ptr = src 
+ c * in_plane_round * kernel_plane; + memset(dst_ptr, 0, output_plane * C8NUM * (int)sizeof(float)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + float *tmp_dst = dst_ptr + dst_index; + const float *tmp_src = src_ptr + src_index; +#ifdef ENABLE_ARM64 + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s, v1.4s}, [x0] \n" + "ld1 {v2.4s, v3.4s}, [x1] \n" + + "fadd v0.4s, v0.4s, v2.4s \n" + "fadd v1.4s, v1.4s, v3.4s \n" + + "st1 {v0.4s, v1.4s}, [x1] \n" + + : + : [tmp_src] "r"(tmp_src), [tmp_dst] "r"(tmp_dst) + : "x0", "x1", "v0", "v1", "v2", "v3"); +#else + for (int i = 0; i < C8NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc8*/ + + PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->act_type_); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.h new file mode 100644 index 00000000..28fe275e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_FP32_DECONV_H_
+#define MINDSPORE_NNACL_FP32_DECONV_H_
+
+#include <string.h>
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp32/common_func_fp32.h"
+#include "nnacl_c/base/minimal_filtering_generator.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
+void DeConvPostFp32C8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
+                      const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_DECONV_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c
new file mode 100644
index 00000000..67376207
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.c
@@ -0,0 +1,733 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/deconv_winograd_fp32.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+int PackDeConvWgDataFp32(const float *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param,
+                         const DeConvParam *deconv_param) {
+#ifdef ENABLE_AVX
+  int tile_num = C8NUM;
+#else
+  int tile_num = C4NUM;
+#endif
+  unsigned int tmp_kernel_plane = unit->w_size_ * unit->h_size_;
+  unsigned int size = conv_param->input_channel_ * conv_param->output_channel_ * tmp_kernel_plane;
+  float *current_unit_weight = (float *)malloc(size * sizeof(float));
+  if (current_unit_weight == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  for (int ic = 0; ic < conv_param->input_channel_; ic++) {
+    const float *src_ic = nhwc_weight + deconv_param->kernel_plane_ * conv_param->output_channel_ * ic;
+    float *dst_ic = current_unit_weight + tmp_kernel_plane * conv_param->output_channel_ * ic;
+    for (int uhi = 0; uhi < unit->h_size_; uhi++) {
+      for (int uwi = 0; uwi < unit->w_size_; uwi++) {
+        int src_h_offset = unit->h_start_ + uhi * conv_param->stride_h_;
+        int src_w_offset = unit->w_start_ + uwi * conv_param->stride_w_;
+        const float *src_hw =
+          src_ic + (src_h_offset * conv_param->kernel_w_ + src_w_offset) * conv_param->output_channel_;
+        float *dst_hw = dst_ic + (uhi * unit->w_size_ + uwi) * conv_param->output_channel_;
+        memcpy(dst_hw, src_hw, conv_param->output_channel_ * sizeof(float));
+      }
+    }
+  }
+
+  if (unit->use_winograd_) {
+    /* Generate winograd */
+    float matrix_g[64], matrix_a[64], matrix_b[64];
+    float matrix_gt[64], matrix_at[64], matrix_bt[64];
+    int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f,
+                             DECONV_WINOGRAD_DEFAULT_UNIT, unit->h_size_);
+    if (ret != NNACL_OK) {
+      free(current_unit_weight);
+      current_unit_weight = NULL;
+      return NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR;
+    }
+
+    /* winograd AT */
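+    /* A shape note, inferred from the mallocs and copies below: AT_ is a dense
+     * row-major i_ x o_ float matrix and BT_ is o_ x o_, both copied out of the
+     * CookToomFilter results above; DeConvWgCalWgFp32 later feeds them to
+     * WinogradTransLeft/WinogradTransRight as at_buf and bt_buf. */
+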
unit->winograd_.AT_ = malloc(unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float)); + if (unit->winograd_.AT_ == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + return NNACL_NULL_PTR; + } + memcpy(unit->winograd_.AT_, matrix_at, unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float)); + + /* winograd BT */ + unit->winograd_.BT_ = malloc(unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float)); + if (unit->winograd_.BT_ == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + return NNACL_NULL_PTR; + } + memcpy(unit->winograd_.BT_, matrix_bt, unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float)); + + /* winograd Weight */ + size = conv_param->input_channel_ * conv_param->output_channel_ * unit->winograd_.kh_ * unit->winograd_.kw_; + float *winograd_unit_weight = (float *)malloc(size * sizeof(float)); + if (winograd_unit_weight == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + if (unit->winograd_.BT_ != NULL) { + free(unit->winograd_.BT_); + unit->winograd_.BT_ = NULL; + } + return NNACL_NULL_PTR; + } + WinogradWeightTransform(current_unit_weight, winograd_unit_weight, matrix_g, matrix_gt, tile_num, + unit->winograd_.kh_, unit->h_size_, conv_param->output_channel_, conv_param->input_channel_, + false); + + /* reset weight data & info */ + tmp_kernel_plane = unit->winograd_.kh_ * unit->winograd_.kw_; + free(current_unit_weight); + current_unit_weight = NULL; + current_unit_weight = winograd_unit_weight; + winograd_unit_weight = NULL; + } + + /* trans mhwc -> hw1:k1-knc0-c4:k1-knc5-c8:hw2:k1-knc0-c4:k1 */ + float *dst_weight = (float *)unit->weight_; + size = deconv_param->ic_up_ * deconv_param->oc_up_ * tmp_kernel_plane; + memset(dst_weight, 0, size * sizeof(float)); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < conv_param->output_channel_; oc++) { + int oc4div = oc / tile_num, oc4mod = oc % tile_num; + for (int upi = 0; upi < tmp_kernel_plane; upi++) { + int src_index = ic * conv_param->output_channel_ * tmp_kernel_plane + upi * conv_param->output_channel_ + oc; + int dst_index = upi * deconv_param->oc_up_ * deconv_param->ic_up_ + oc4div * tile_num * deconv_param->ic_up_ + + ic * tile_num + oc4mod; + dst_weight[dst_index] = current_unit_weight[src_index]; + } + } + } + + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + return NNACL_OK; +} + +void DeConvWgInputPack(const float *src_ptr, float *dst_ptr, int channel, int stride) { +#ifdef ENABLE_AVX + int ic_tile = C8NUM; +#else + int ic_tile = C4NUM; +#endif + int ic4div = channel / ic_tile; + int ic4mod = channel % ic_tile; + const float *src = src_ptr; + float *dst = dst_ptr; + + for (int ic = 0; ic < ic4div; ic++) { +#ifdef ENABLE_AVX + MS_ST256_F32(dst, MS_LD256_F32(src)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst, MS_LDQ_F32(src)); +#else + memcpy(dst, src, C4NUM * sizeof(float)); +#endif + dst += stride; + src += ic_tile; + } + + if (ic4mod != 0) { + int ic_res = 0; + for (; ic_res < ic4mod; ic_res++) { + dst[ic_res] = src[ic_res]; + } + for (; ic_res < ic_tile; ic_res++) { + dst[ic_res] = 0; + } + } +} + +#if !defined(ENABLE_ARM) && 
!defined(ENABLE_SSE) +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { + int dx, sz, dz; + const int src_depth_step = C4NUM * DECONV_WINOGRAD_DEFAULT_TILE; + for (dz = 0; dz < oc4; ++dz) { + float *dst_z = dst + dz * cal_num; + const float *weight_dz = weight + dz * ic4 * C16NUM; + for (dx = 0; dx < DECONV_WINOGRAD_DEFAULT_TILE; ++dx) { + float *dst_x = dst_z + dx * C4NUM; + dst_x[0] = 0.0f; + dst_x[1] = 0.0f; + dst_x[2] = 0.0f; + dst_x[3] = 0.0f; + const float *src_dx = src + C4NUM * dx; + for (sz = 0; sz < ic4; ++sz) { + const float *src_z = src_dx + sz * src_depth_step; + const float *weight_z = weight_dz + sz * C16NUM; + for (int i = 0; i < C4NUM; ++i) { + for (int j = 0; j < C4NUM; ++j) { + dst_x[j] += src_z[i] * weight_z[C4NUM * i + j]; + } + } + } + } + } +} +#endif + +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r11, %[src_ptr]\n" + "mov r8, %[dst_ptr]\n" + "mov r10, r8\n" + + "vld1.32 {q0}, [r11], %[src_step]\n" + "vld1.32 {q1}, [r8], %[dst_step]\n" + "vld1.32 {q2}, [r11], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vld1.32 {q8}, [r11], %[src_step]\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r11], %[src_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + "vadd.f32 q10, q10, q11\n" + + "vld1.32 {q0}, [r11], %[src_step]\n" + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + "vld1.32 {q1}, [r8], %[dst_step]\n" + + "vld1.32 {q2}, [r11], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q8}, [r11], %[src_step]\n" + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r11], %[src_step]\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vadd.f32 q10, q10, q11\n" + + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r8", "r10", "r11", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); + return; +} +#else +void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r7, %[src_ptr]\n" + "mov r8, %[dst_ptr]\n" + "mov r10, r8\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vld1.32 {q1}, [r8], %[dst_step]\n" + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vld1.32 {q8}, [r7], %[src_step]\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + "vadd.f32 q10, q10, q11\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + "vld1.32 {q1}, [r8], %[dst_step]\n" + + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q8}, [r7], 
%[src_step]\n" + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vadd.f32 q10, q10, q11\n" + + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r8", "r10", "r7", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); + return; +} +#endif +#endif + +#ifdef ENABLE_AVX +void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float *src_ptr = src; + float *dst_ptr = dst; + size_t count8 = count / C8NUM * C8NUM; + size_t count4 = count / C4NUM * C4NUM; + int i = 0; + for (; i < count8; i += C8NUM) { + MS_FLOAT32X8 src1 = MS_LD256_F32(src_ptr + 0 * src_stride); + MS_FLOAT32X8 src2 = MS_LD256_F32(src_ptr + 1 * src_stride); + MS_FLOAT32X8 src3 = MS_LD256_F32(src_ptr + 2 * src_stride); + MS_FLOAT32X8 src4 = MS_LD256_F32(src_ptr + 3 * src_stride); + MS_FLOAT32X8 src5 = MS_LD256_F32(src_ptr + 4 * src_stride); + MS_FLOAT32X8 src6 = MS_LD256_F32(src_ptr + 5 * src_stride); + MS_FLOAT32X8 src7 = MS_LD256_F32(src_ptr + 6 * src_stride); + MS_FLOAT32X8 src8 = MS_LD256_F32(src_ptr + 7 * src_stride); + MS_FLOAT32X8 dst1 = MS_LD256_F32(dst_ptr + 0 * dst_stride); + MS_FLOAT32X8 dst2 = MS_LD256_F32(dst_ptr + 1 * dst_stride); + MS_FLOAT32X8 dst3 = MS_LD256_F32(dst_ptr + 2 * dst_stride); + MS_FLOAT32X8 dst4 = MS_LD256_F32(dst_ptr + 3 * dst_stride); + MS_FLOAT32X8 dst5 = MS_LD256_F32(dst_ptr + 4 * dst_stride); + MS_FLOAT32X8 dst6 = MS_LD256_F32(dst_ptr + 5 * dst_stride); + MS_FLOAT32X8 dst7 = MS_LD256_F32(dst_ptr + 6 * dst_stride); + MS_FLOAT32X8 dst8 = MS_LD256_F32(dst_ptr + 7 * dst_stride); + dst1 = MS_ADD256_F32(dst1, src1); + dst2 = MS_ADD256_F32(dst2, src2); + dst3 = MS_ADD256_F32(dst3, src3); + dst4 = MS_ADD256_F32(dst4, src4); + dst5 = MS_ADD256_F32(dst5, src5); + dst6 = MS_ADD256_F32(dst6, src6); + dst7 = MS_ADD256_F32(dst7, src7); + dst8 = MS_ADD256_F32(dst8, src8); + MS_ST256_F32(dst_ptr + 0 * dst_stride, dst1); + MS_ST256_F32(dst_ptr + 1 * dst_stride, dst2); + MS_ST256_F32(dst_ptr + 2 * dst_stride, dst3); + MS_ST256_F32(dst_ptr + 3 * dst_stride, dst4); + MS_ST256_F32(dst_ptr + 4 * dst_stride, dst5); + MS_ST256_F32(dst_ptr + 5 * dst_stride, dst6); + MS_ST256_F32(dst_ptr + 6 * dst_stride, dst7); + MS_ST256_F32(dst_ptr + 7 * dst_stride, dst8); + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + for (; i < count4; i += C4NUM) { + MS_FLOAT32X8 src1 = MS_LD256_F32(src_ptr + 0 * src_stride); + MS_FLOAT32X8 src2 = MS_LD256_F32(src_ptr + 1 * src_stride); + MS_FLOAT32X8 src3 = MS_LD256_F32(src_ptr + 2 * src_stride); + MS_FLOAT32X8 src4 = MS_LD256_F32(src_ptr + 3 * src_stride); + MS_FLOAT32X8 dst1 = MS_LD256_F32(dst_ptr + 0 * dst_stride); + MS_FLOAT32X8 dst2 = MS_LD256_F32(dst_ptr + 1 * dst_stride); + MS_FLOAT32X8 dst3 = MS_LD256_F32(dst_ptr + 2 * dst_stride); + MS_FLOAT32X8 dst4 = MS_LD256_F32(dst_ptr + 3 * dst_stride); + dst1 = MS_ADD256_F32(dst1, src1); + dst2 = MS_ADD256_F32(dst2, src2); + dst3 = MS_ADD256_F32(dst3, src3); + dst4 = MS_ADD256_F32(dst4, src4); + MS_ST256_F32(dst_ptr + 0 * dst_stride, dst1); + MS_ST256_F32(dst_ptr + 1 * dst_stride, dst2); + MS_ST256_F32(dst_ptr + 2 * dst_stride, dst3); + MS_ST256_F32(dst_ptr + 3 * dst_stride, dst4); + src_ptr += C4NUM * src_stride; + dst_ptr += C4NUM * dst_stride; + } + for (; i < count; i++) { + MS_FLOAT32X8 src_data = MS_LD256_F32(src_ptr); + MS_FLOAT32X8 
dst_data = MS_LD256_F32(dst_ptr); + dst_data = MS_ADD256_F32(src_data, dst_data); + MS_ST256_F32(dst_ptr, dst_data); + src_ptr += src_stride; + dst_ptr += dst_stride; + } +} +#else +void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float *src_ptr = src; + float *dst_ptr = dst; + size_t count8 = count / C8NUM * C8NUM; + size_t count4 = count / C4NUM * C4NUM; + int i = 0; + for (; i < count8; i += C8NUM) { +#ifdef ENABLE_ARM64 + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + asm volatile( + "mov x7, %[src_ptr]\n" + "mov x8, %[dst_ptr]\n" + "mov x10, x8\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "ld1 {v4.4s}, [x7], %[src_step]\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], %[src_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v4.4s}, [x7], %[src_step]\n" + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], %[src_step]\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + : + : [src_ptr] "r"(src_ptr), [dst_ptr] "r"(dst_ptr), [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#elif defined(ENABLE_ARM32) + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + DeConvWgMergeArm32(src_ptr, dst_ptr, src_step, dst_step); +#elif defined(ENABLE_SSE) + MS_STQ_F32(dst_ptr + 0 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 0 * dst_stride), MS_LDQ_F32(src_ptr + 0 * src_stride))); + MS_STQ_F32(dst_ptr + 1 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 1 * dst_stride), MS_LDQ_F32(src_ptr + 1 * src_stride))); + MS_STQ_F32(dst_ptr + 2 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 2 * dst_stride), MS_LDQ_F32(src_ptr + 2 * src_stride))); + MS_STQ_F32(dst_ptr + 3 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 3 * dst_stride), MS_LDQ_F32(src_ptr + 3 * src_stride))); + MS_STQ_F32(dst_ptr + C4NUM * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + C4NUM * dst_stride), MS_LDQ_F32(src_ptr + C4NUM * src_stride))); + MS_STQ_F32(dst_ptr + 5 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 5 * dst_stride), MS_LDQ_F32(src_ptr + 5 * src_stride))); + MS_STQ_F32(dst_ptr + 6 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 6 * dst_stride), MS_LDQ_F32(src_ptr + 6 * src_stride))); + MS_STQ_F32(dst_ptr + 7 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 7 * dst_stride), MS_LDQ_F32(src_ptr + 7 * src_stride))); +#else + for (int j = 0; j < C8NUM; j++) { + const float *s = src_ptr + j * src_stride; + float *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * 
dst_stride; + } + for (; i < count4; i += C4NUM) { +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_STQ_F32(dst_ptr + 0 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 0 * dst_stride), MS_LDQ_F32(src_ptr + 0 * src_stride))); + MS_STQ_F32(dst_ptr + 1 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 1 * dst_stride), MS_LDQ_F32(src_ptr + 1 * src_stride))); + MS_STQ_F32(dst_ptr + 2 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 2 * dst_stride), MS_LDQ_F32(src_ptr + 2 * src_stride))); + MS_STQ_F32(dst_ptr + 3 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 3 * dst_stride), MS_LDQ_F32(src_ptr + 3 * src_stride))); +#else + for (int j = 0; j < C4NUM; j++) { + const float *s = src_ptr + j * src_stride; + float *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C4NUM * src_stride; + dst_ptr += C4NUM * dst_stride; + } + for (; i < count; i++) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src_data = MS_LDQ_F32(src_ptr); + MS_FLOAT32X4 dst_data = MS_LDQ_F32(dst_ptr); + dst_data = MS_ADDQ_F32(src_data, dst_data); + MS_STQ_F32(dst_ptr, dst_data); +#else + for (int j = 0; j < C4NUM; j++) { + dst_ptr[j] += src_ptr[j]; + } +#endif + src_ptr += src_stride; + dst_ptr += dst_stride; + } +} +#endif + +void DeConvWgCalWgFp32(const float *tile_in, float *tile_out, const float *weight_buf, float *tmp_buf, + const float *at_buf, float *a_mid_buf, float *trans_a_buf, bool *transferred, + const float *bt_buf, float *b_tmp_buf, int unit_size, int w_start, int h_start, + const ConvParameter *conv_param, const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; + TiledMatmulFp32 matmul_func = TiledC8MatmulFp32; +#else + TiledMatmulFp32 matmul_func = TiledC4MatmulFp32; + int tile_num = C4NUM; +#endif + int winograd_plane = unit_size * unit_size; + if (!transferred[unit_size]) { + WinogradTransLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + transferred[unit_size] = true; + } + + for (int index = 0; index < winograd_plane; index++) { + float *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *dst = tmp_buf + index * deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + const float *weight = weight_buf + index * deconv_param->ic_up_ * deconv_param->oc_up_; + matmul_func(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * tile_num, deconv_param->ic_div_, + deconv_param->oc_div_); + } + WinogradTransLeft(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + + // Add to dest + for (int uhi = 0; uhi < unit_size; uhi++) { + int h_index = uhi * conv_param->stride_h_ + h_start; + for (int uwi = 0; uwi < unit_size; uwi++) { + int w_index = uwi * conv_param->stride_w_ + w_start; + + float *dst = tile_out + w_index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_ + + h_index * deconv_param->out_tile_w_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + float *src = tmp_buf + (uwi + uhi * unit_size) * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgMerge(src, dst, tile_num, tile_num, 
DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } +} + +void DeConvWgCalCommFp32(const float *tile_in, float *tile_out, const float *weight, float *tmp_buf, int h_start, + int w_start, int h_size, int w_size, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; + TiledMatmulFp32 matmul_func = TiledC8MatmulFp32; +#else + TiledMatmulFp32 matmul_func = TiledC4MatmulFp32; + int tile_num = C4NUM; +#endif + int count = deconv_param->oc_div_ * w_size * h_size; + int in_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + int out_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + const float *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride; + matmul_func(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * tile_num, deconv_param->ic_div_, count); + + for (int uhi = 0; uhi < h_size; uhi++) { + for (int uwi = 0; uwi < w_size; uwi++) { + int w_index = (wi + uwi) * conv_param->stride_w_ + w_start; + int h_index = (hi + uhi) * conv_param->stride_h_ + h_start; + float *dst = tile_out + h_index * out_stride * deconv_param->out_tile_w_ + w_index * out_stride; + float *src = tmp_buf + (uwi + uhi * w_size) * out_stride; + DeConvWgMerge(src, dst, tile_num, tile_num, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + } + } +} + +int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count, + const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) { + if (deconv_param->in_tile_w_count_ == 0) { + return NNACL_ERR; + } + /* pack tile input */ + int tile_in_unit_stride = deconv_param->ic_up_ * DECONV_WINOGRAD_DEFAULT_TILE; +#ifdef ENABLE_AVX + int tile_num = C8NUM; + MS_FLOAT32X8 zero = MS_MOV256_F32(0.0f); +#else + int tile_num = C4NUM; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); +#endif +#endif + for (int unit_index = 0; unit_index < calculate_count; unit_index++) { + int plane_index = start_index + unit_index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + + float *dst_unit = tile_in + unit_index * tile_num; + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + float *dst = dst_unit + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * tile_in_unit_stride; + int w_index = w_start + wi; + int h_index = h_start + hi; + if (w_index >= conv_param->input_w_ || h_index >= conv_param->input_h_) { + for (int ic4_index = 0; ic4_index < deconv_param->ic_div_; ic4_index++) { +#ifdef ENABLE_AVX + MS_ST256_F32(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * tile_num, zero); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * tile_num, zero); +#else + for (int i = 0; i < tile_num; i++) { + dst[tile_num * DECONV_WINOGRAD_DEFAULT_TILE * ic4_index + i] = 0; + } +#endif + } + continue; + } + + const float *src = nhwc_input_ + (w_index + h_index * conv_param->input_w_) * conv_param->input_channel_; + DeConvWgInputPack(src, dst, conv_param->input_channel_, DECONV_WINOGRAD_DEFAULT_TILE * tile_num); + } + } + } + + /* compute */ + bool 
transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; + for (int i = 0; i < deconv_param->compute_size_; i++) { + DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; + if (unit->use_winograd_) { + float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + + /* winograd a buffer */ + if (unit->winograd_.kh_ >= DECONV_WINOGRAD_BUFFER_COUNT || unit->winograd_.AT_ == NULL) { + return NNACL_ERR; + } + DeConvWgABuffer *wg_buf = &deconv_param->a_buffer_[unit->winograd_.kh_]; + float *wg_mid_a_buf = (float *)wg_buf->middle_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *wg_dst_a_buf = (float *)wg_buf->dest_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *tmp_b_buf = (float *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + DeConvWgCalWgFp32(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->winograd_.AT_, wg_mid_a_buf, + wg_dst_a_buf, transferred, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, unit->w_start_, + unit->h_start_, conv_param, deconv_param); + } else { + float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * deconv_param->oc_div_ * unit->w_size_ * unit->h_size_ * + DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + DeConvWgCalCommFp32(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->h_start_, unit->w_start_, + unit->h_size_, unit->w_size_, conv_param, deconv_param); + } + } + return NNACL_OK; +} + +int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; +#else + int tile_num = C4NUM; +#endif + + /* merge */ + int src_unit_stride = deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + int src_stride = DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + int dst_stride = conv_param->output_w_ * conv_param->output_h_ * tile_num; + + for (int index = 0; index < calculate_count; ++index) { + const float *src_start = tile_out + index * tile_num; + + int plane_index = tile_index * DECONV_WINOGRAD_DEFAULT_TILE + index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_w_ - conv_param->pad_l_; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_h_ - conv_param->pad_u_; + float *dst_start = nc4hw4_output + h_start * conv_param->output_w_ * tile_num + w_start * tile_num; + + int merge_w_start = MSMAX(-w_start, 0); + int merge_h_start = MSMAX(-h_start, 0); + int merge_h_end = MSMIN(deconv_param->out_tile_h_, conv_param->output_h_ - h_start); + int merge_w_end = MSMIN(deconv_param->out_tile_w_, conv_param->output_w_ - w_start); + + for (int hi = merge_h_start; hi < merge_h_end; hi++) { + for (int wi = merge_w_start; wi < merge_w_end; wi++) { + const float *src = src_start + (hi * deconv_param->out_tile_w_ + wi) * src_unit_stride; + float *dst = dst_start + (hi * conv_param->output_w_ + wi) * tile_num; + DeConvWgMerge(src, dst, src_stride, dst_stride, deconv_param->oc_div_); + } + } + } + return NNACL_OK; +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.h
new file mode 100644
index 00000000..576ce4e7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/deconv_winograd_fp32.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_
+#define MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_
+
+#include <string.h>
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp32/common_func_fp32.h"
+#include "nnacl_c/base/minimal_filtering_generator.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+typedef void (*TiledMatmulFp32)(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic_tiled,
+                                size_t oc_tiled);
+
+int PackDeConvWgDataFp32(const float *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param,
+                         const DeConvParam *deconv_param);
+int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count,
+             const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id);
+int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param,
+                 const DeConvParam *deconv_param, int calculate_count, int tile_index);
+void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4);
+void TiledC8MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic8, size_t oc8);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.c
new file mode 100644
index 00000000..45c29cbf
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.c
@@ -0,0 +1,235 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/detection_post_process_fp32.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/nnacl_utils.h"
+
+float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
+  const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin);
+  const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin);
+  if (area_a <= 0 || area_b <= 0) {
+    return 0.0f;
+  }
+  const float ymin = a->ymin > b->ymin ? a->ymin : b->ymin;
+  const float xmin = a->xmin > b->xmin ? a->xmin : b->xmin;
+  const float ymax = a->ymax < b->ymax ? a->ymax : b->ymax;
+  const float xmax = a->xmax < b->xmax ? a->xmax : b->xmax;
+  const float h = ymax - ymin > 0.0f ? ymax - ymin : 0.0f;
+  const float w = xmax - xmin > 0.0f ? xmax - xmin : 0.0f;
+  const float inter = h * w;
+  return inter / (area_a + area_b - inter);
+}
+
+int DecodeBoxes(int num_boxes, const float *input_boxes, const float *anchors,
+                const DetectionPostProcessParameter *param) {
+  if (input_boxes == NULL || anchors == NULL || param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  float *decoded_boxes = (float *)param->decoded_boxes_;
+  BboxCenter scaler;
+  scaler.y = param->y_scale_;
+  scaler.x = param->x_scale_;
+  scaler.h = param->h_scale_;
+  scaler.w = param->w_scale_;
+  for (int i = 0; i < num_boxes; ++i) {
+    BboxCenter *box = (BboxCenter *)(input_boxes) + i;
+    BboxCenter *anchor = (BboxCenter *)(anchors) + i;
+    BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes) + i;
+    float y_center = box->y / scaler.y * anchor->h + anchor->y;
+    float x_center = box->x / scaler.x * anchor->w + anchor->x;
+    const float h_half = 0.5f * expf(box->h / scaler.h) * anchor->h;
+    const float w_half = 0.5f * expf(box->w / scaler.w) * anchor->w;
+    decoded_box->ymin = y_center - h_half;
+    decoded_box->xmin = x_center - w_half;
+    decoded_box->ymax = y_center + h_half;
+    decoded_box->xmax = x_center + w_half;
+  }
+  return NNACL_OK;
+}
+
+int NmsSingleClass(const int num_boxes, const float *decoded_boxes, const int max_detections, const float *scores,
+                   int32_t *selected, void (*PartialArgSort)(const float *, int32_t *, int, int),
+                   const DetectionPostProcessParameter *param) {
+  if (PartialArgSort == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  uint8_t *nms_candidate = param->nms_candidate_;
+  const int output_num = num_boxes < max_detections ?
num_boxes : max_detections; + int possible_candidate_num = num_boxes; + int selected_num = 0; + int32_t *indexes = (int32_t *)param->single_class_indexes_; + for (int i = 0; i < num_boxes; ++i) { + indexes[i] = i; + nms_candidate[i] = 1; + } + PartialArgSort(scores, indexes, num_boxes, num_boxes); + for (int i = 0; i < num_boxes; ++i) { + if (possible_candidate_num == 0 || selected_num >= output_num || scores[indexes[i]] < param->nms_score_threshold_) { + break; + } + if (nms_candidate[indexes[i]] == 0) { + continue; + } + selected[selected_num++] = indexes[i]; + nms_candidate[indexes[i]] = 0; + possible_candidate_num--; + const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + indexes[i]; + for (int t = i + 1; t < num_boxes; ++t) { + if (scores[indexes[t]] < param->nms_score_threshold_) break; + if (nms_candidate[indexes[t]] == 1) { + const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + indexes[t]; + const float iou = IntersectionOverUnion(bbox_i, bbox_t); + if (iou > param->nms_iou_threshold_) { + nms_candidate[indexes[t]] = 0; + possible_candidate_num--; + } + } + } + } + return selected_num; +} + +int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param, const int task_id, const int thread_num) { + if (input_scores == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + const int64_t max_classes_per_anchor = + param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_; + float *scores = (float *)param->scores_; + for (int i = task_id; i < num_boxes; i += thread_num) { + int32_t *indexes = (int32_t *)param->indexes_ + i * param->num_classes_; + for (int j = 0; j < param->num_classes_; ++j) { + indexes[j] = i * num_classes_with_bg + first_class_index + j; + } + PartialArgSort(input_scores, indexes, max_classes_per_anchor, param->num_classes_); + scores[i] = input_scores[indexes[0]]; + } + return NNACL_OK; +} + +int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + const float *decoded_boxes, float *output_boxes, float *output_classes, + float *output_scores, float *output_num, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (input_scores == NULL || decoded_boxes == NULL || output_boxes == NULL || output_classes == NULL || + output_scores == NULL || output_num == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + int out_num = 0; + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + const int64_t max_classes_per_anchor = + param->max_classes_per_detection_ < param->num_classes_ ? 
param->max_classes_per_detection_ : param->num_classes_; + int32_t *selected = (int32_t *)param->selected_; + int selected_num = NmsSingleClass(num_boxes, decoded_boxes, param->max_detections_, (float *)param->scores_, selected, + PartialArgSort, param); + for (int i = 0; i < selected_num; ++i) { + int32_t *indexes = (int32_t *)param->indexes_ + selected[i] * param->num_classes_; + BboxCorner *box = (BboxCorner *)(decoded_boxes) + selected[i]; + for (int j = 0; j < max_classes_per_anchor; ++j) { + *((BboxCorner *)(output_boxes) + out_num) = *box; + output_scores[out_num] = input_scores[indexes[j]]; + NNACL_ASSERT(num_classes_with_bg != 0); + output_classes[out_num++] = (float)(indexes[j] % num_classes_with_bg - first_class_index); + } + } + *output_num = (float)out_num; + for (int i = out_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) { + ((BboxCorner *)(output_boxes) + i)->ymin = 0; + ((BboxCorner *)(output_boxes) + i)->xmin = 0; + ((BboxCorner *)(output_boxes) + i)->ymax = 0; + ((BboxCorner *)(output_boxes) + i)->xmax = 0; + output_scores[i] = 0; + output_classes[i] = 0; + } + return NNACL_OK; +} + +int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + float *output_boxes, float *output_classes, float *output_scores, float *output_num, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (input_scores == NULL || output_boxes == NULL || output_classes == NULL || output_scores == NULL || + output_num == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + float *decoded_boxes = (float *)param->decoded_boxes_; + int32_t *selected = (int32_t *)param->selected_; + float *scores = (float *)param->scores_; + float *all_scores = (float *)param->all_class_scores_; + int32_t *indexes = (int32_t *)(param->indexes_); + int32_t *all_indexes = (int32_t *)(param->all_class_indexes_); + int all_classes_sorted_num = 0; + int all_classes_output_num = 0; + for (int j = first_class_index; j < num_classes_with_bg; ++j) { + // process single class + for (int i = 0; i < num_boxes; ++i) { + scores[i] = input_scores[i * num_classes_with_bg + j]; + } + int selected_num = + NmsSingleClass(num_boxes, decoded_boxes, param->detections_per_class_, scores, selected, PartialArgSort, param); + for (int i = 0; i < all_classes_sorted_num; ++i) { + indexes[i] = all_indexes[i]; + all_indexes[i] = i; + } + // process all classes + for (int i = 0; i < selected_num; ++i) { + indexes[all_classes_sorted_num] = selected[i] * num_classes_with_bg + j; + all_indexes[all_classes_sorted_num] = all_classes_sorted_num; + all_scores[all_classes_sorted_num++] = scores[selected[i]]; + } + all_classes_output_num = + all_classes_sorted_num < param->max_detections_ ? 
all_classes_sorted_num : param->max_detections_; + PartialArgSort(all_scores, all_indexes, all_classes_output_num, all_classes_sorted_num); + for (int i = 0; i < all_classes_output_num; ++i) { + scores[i] = all_scores[all_indexes[i]]; + all_indexes[i] = indexes[all_indexes[i]]; + } + for (int i = 0; i < all_classes_output_num; ++i) { + all_scores[i] = scores[i]; + } + all_classes_sorted_num = all_classes_output_num; + } + for (int i = 0; i < param->max_detections_ * param->max_classes_per_detection_; ++i) { + if (i < all_classes_output_num) { + NNACL_CHECK_ZERO_RETURN_ERR(num_classes_with_bg); + const int box_index = all_indexes[i] / num_classes_with_bg; + const int class_index = all_indexes[i] % num_classes_with_bg - first_class_index; + *((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index); + output_classes[i] = (float)class_index; + output_scores[i] = all_scores[i]; + } else { + ((BboxCorner *)(output_boxes) + i)->ymin = 0; + ((BboxCorner *)(output_boxes) + i)->xmin = 0; + ((BboxCorner *)(output_boxes) + i)->ymax = 0; + ((BboxCorner *)(output_boxes) + i)->xmax = 0; + output_classes[i] = 0.0f; + output_scores[i] = 0.0f; + } + } + *output_num = (float)all_classes_output_num; + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.h new file mode 100644 index 00000000..4df46a79 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/detection_post_process_fp32.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_ +#define MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/detection_post_process_parameter.h" + +typedef struct { + float y; + float x; + float h; + float w; +} BboxCenter; + +typedef struct { + float ymin; + float xmin; + float ymax; + float xmax; +} BboxCorner; + +#ifdef __cplusplus +extern "C" { +#endif +int DecodeBoxes(int num_boxes, const float *input_boxes, const float *anchors, + const DetectionPostProcessParameter *param); + +int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + void (*)(const float *, int32_t *, int, int), const DetectionPostProcessParameter *param, + const int task_id, const int thread_num); + +int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + const float *decoded_boxes, float *output_boxes, float *output_classes, + float *output_scores, float *output_num, void (*)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param); + +int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + float *output_boxes, float *output_classes, float *output_scores, float *output_num, + void (*)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.c new file mode 100644 index 00000000..a0aa229b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.c @@ -0,0 +1,136 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/div_fp32.h" +#include +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/div_fp32_simd.h" + +int ElementOptDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] / in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptDivRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivReluNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] / in1[index]; + out[index] = out[index] > 0 ? out[index] : 0; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivReluNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + out[index] = out[index] > 0 ? 
out[index] : 0; + } + } + return NNACL_OK; +} + +int ElementOptDivRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivRelu6Num0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] / in1[index], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivRelu6Num1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] / in1[0], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + } + return NNACL_OK; +} + +int ElementOptDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivIntNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[index] != 0); + out[index] = in0[0] / in1[index]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0] != 0); + + SIMD_RUN_NO_SCALAR(ElementOptDivIntNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + } + } + return NNACL_OK; +} + +int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementDiv(tile_in0, tile_in1, out, size); +} + +int ElementDiv(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDiv, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] / in1[index]; + } + return NNACL_OK; +} + +int ElementDivRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDivRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] / in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementDivRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDivRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] / in1[index], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.h new file mode 100644 index 00000000..8a966e39 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
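
Every ElementDiv* kernel above has the same dispatch shape: SIMD_RUN_NO_SCALAR hands `index` to the widest generated SIMD body, that body consumes whole BLOCK_NUM-element blocks and returns the first index it did not process, and the plain C loop finishes the tail. The contract, de-SIMD-ed into plain C with a hypothetical 4-wide block (sketch, not part of the patch):

  /* Stand-in for a generated SIMD body: processes blocks of 4, returns the resume index. */
  static int ElementDivBlock4(int index, const float *in0, const float *in1, float *out, int size) {
    for (; index <= size - 4; index += 4) {
      for (int lane = 0; lane < 4; ++lane) {
        out[index + lane] = in0[index + lane] / in1[index + lane];
      }
    }
    return index;
  }

  int ElementDivSketch(const float *in0, const float *in1, float *out, int size) {
    int index = ElementDivBlock4(0, in0, in1, out, size); /* vector part */
    for (; index < size; index++) {                       /* scalar tail  */
      out[index] = in0[index] / in1[index];
    }
    return 0; /* NNACL_OK in the real kernels */
  }
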
+ */
+#ifndef MINDSPORE_NNACL_FP32_DIV_H_
+#define MINDSPORE_NNACL_FP32_DIV_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/arithmetic_base.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ElementDiv(const float *in0, const float *in1, float *out, int size);
+int ElementDivRelu(const float *in0, const float *in1, float *out, int size);
+int ElementDivRelu6(const float *in0, const float *in1, float *out, int size);
+int ElementOptDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptDivRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptDivRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar);
+int ElementOptDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar);
+int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
+                 ArithmeticParameter *param);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_DIV_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32_simd.h.in
new file mode 100644
index 00000000..1495b6f5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/div_fp32_simd.h.in
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
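
The .h.in file below is a template: the build system presumably expands @SIMD_INSTRUCTION@, @SIMD_INSTRUCTION_LOWER@ and BLOCK_NUM once per enabled instruction set into a generated header. Assuming an SSE instantiation where BLOCK_NUM is 4 and SIMD_LD_F32 / SIMD_DIV_F32 / SIMD_ST_F32 resolve to the usual unaligned intrinsics (an assumption about the macro layer, for illustration only), the expanded ElementDiv body would read roughly like:

  #include <xmmintrin.h>

  /* Hypothetical expansion of ElementDiv@SIMD_INSTRUCTION@ for SSE, BLOCK_NUM = 4. */
  static inline int ElementDivSSE(int index, const float *in0, const float *in1, float *out, int size) {
    for (int block_max_size = size - 4 + 1; index < block_max_size; index += 4) {
      __m128 vin0 = _mm_loadu_ps(in0 + index);
      __m128 vin1 = _mm_loadu_ps(in1 + index);
      _mm_storeu_ps(out + index, _mm_div_ps(vin0, vin1));
    }
    return index;
  }
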
+ */ + +#ifndef MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptDivNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0_opt, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0_opt, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0_opt, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0_opt, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 
vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDiv@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDivInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementDivRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDivRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.c new file mode 100644 index 00000000..65630afc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.c @@ -0,0 +1,28 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/fp32/dropout_fp32.h" +#include "nnacl_c/dropout_fp32_simd.h" + +void DropoutFp32(const float *input, float scale, int length, float *output) { + int i = 0; + + SIMD_RUN_NO_SCALAR(DropoutFp32, i, input, scale, length, output); + + for (; i < length; ++i) { + output[i] = scale * input[i]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.h new file mode 100644 index 00000000..50193b60 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32.h @@ -0,0 +1,28 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ +#define MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DropoutFp32(const float *input, float scale, int length, float *output); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32_simd.h.in new file mode 100644 index 00000000..36109cd8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/dropout_fp32_simd.h.in @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
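
DropoutFp32 above is a pure scaling kernel; the dropout semantics live with the caller. Under the usual inverted-dropout convention (an assumption about how callers derive `scale`, not something this file enforces), the scale keeps the expected activation unchanged:

  /* Sketch: derive the inverted-dropout scale from the drop ratio. */
  static void DropoutScaleRef(const float *input, float ratio, int length, float *output) {
    float scale = 1.0f / (1.0f - ratio); /* e.g. ratio 0.2f -> scale 1.25f */
    for (int i = 0; i < length; ++i) {
      output[i] = input[i] * scale; /* same arithmetic as DropoutFp32 */
    }
  }
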
+ */
+#ifndef MINDSPORE_NNACL_FP32_DROPOUTFP32_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_DROPOUTFP32_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int DropoutFp32@SIMD_INSTRUCTION@(int index, const float *input, float scale,
+                                                int length, float *output) {
+  SIMD_F32 scale_value = SIMD_MOV_F32(scale);
+  for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_MUL_F32(SIMD_LD_F32(input + index), scale_value));
+  }
+  return index;
+}
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.c
new file mode 100644
index 00000000..f2c6425c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.c
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/embedding_lookup_fp32.h"
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+void l2_regulate(float *data, int size, float max_norm) {
+  float sum = 0;
+  for (int i = 0; i < size; ++i) {
+    sum += data[i];
+  }
+  if (sum != 0) {
+    for (int i = 0; i < size; ++i) {
+      data[i] *= max_norm / sum;
+    }
+  }
+  return;
+}
+
+int CopyData(float *input_data, const int32_t *ids, float *output_data, int num,
+             const EmbeddingLookupParameter *parameter) {
+  if (ids[num] >= parameter->layer_num_ || ids[num] < 0) {
+    return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
+  }
+  float *out_data = output_data + num * parameter->layer_size_;
+  float *in_data = input_data + ids[num] * parameter->layer_size_;
+  if (!parameter->is_regulated_[ids[num]]) {
+    l2_regulate(in_data, parameter->layer_size_, parameter->max_norm_);
+    parameter->is_regulated_[ids[num]] = true;
+  }
+
+  memcpy(out_data, in_data, sizeof(float) * (size_t)(parameter->layer_size_));
+  return NNACL_OK;
+}
+
+int EmbeddingLookup(float *input_data, const int32_t *ids, float *output_data,
+                    const EmbeddingLookupParameter *parameter, int task_id) {
+  if (parameter->op_parameter_.thread_num_ == 0) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (int i = task_id; i < parameter->ids_size_; i += parameter->op_parameter_.thread_num_) {
+    int ret = CopyData(input_data, ids, output_data, i, parameter);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.h
new file mode 100644
index 00000000..236d053d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/embedding_lookup_fp32.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ +#define MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ + +#include "nnacl_c/op_base.h" + +typedef struct EmbeddingLookupParameter { + OpParameter op_parameter_; + // primitive parameter + float max_norm_; + + // shape correlative + bool *is_regulated_; + int ids_size_; + int layer_size_; + int layer_num_; +} EmbeddingLookupParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int EmbeddingLookup(float *input_data, const int32_t *ids, float *output_data, + const EmbeddingLookupParameter *parameter, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c new file mode 100644 index 00000000..cab8c075 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.c @@ -0,0 +1,62 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
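
EmbeddingLookup above parallelizes by striding over output rows: task `task_id` of `thread_num` handles rows task_id, task_id + thread_num, and so on, so each output row is written by exactly one task. Two things worth noting: l2_regulate, despite its name, rescales by the plain sum of the row rather than a Euclidean norm, and the one-time regulation is keyed by ids, so two tasks hitting the same id would touch shared state (presumably ruled out or tolerated by the caller). The partition pattern in isolation (sketch):

  /* Strided row partition used by EmbeddingLookup and similar nnacl kernels. */
  static void ForEachRowOfTask(int task_id, int thread_num, int num_rows,
                               void (*handle_row)(int row, void *ctx), void *ctx) {
    for (int row = task_id; row < num_rows; row += thread_num) {
      handle_row(row, ctx); /* rows of different tasks never overlap */
    }
  }
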
+ */ + +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/exp_fp32_simd.h" +#include +#include +#include "nnacl_c/errorcode.h" + +void ExpFp32(const float *src, float *dst, int num) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ExpFp32, i, src, dst, num); + for (; i < num; ++i) { + simd_exp32(src[i], dst + i); + } +} + +int ExpFusionFp32(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) { + NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + const float *src = (const float *)src_data; + float *dst = (float *)dst_data; + + int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_); + int start = stride * task_id; + int end = MSMIN(exp->element_num_, start + stride); + int num = end - start; + + if (param->scale_ == 1) { + ExpFp32(src + start, dst + start, num); + } else { + int i = 0; + SIMD_RUN_NO_SCALAR(ExpFp32WithInScale, i, src, dst, num, exp->in_scale_); + for (; i < num; ++i) { + simd_exp32(src[i] * exp->in_scale_, dst + i); + } + } + if (exp->out_scale_ != 1) { + int i = 0; + SIMD_RUN_NO_SCALAR(ExpFp32WithOutScale, i, src, dst, num, exp->out_scale_); + for (; i < num; ++i) { + simd_exp32(src[i], dst + i); + dst[i] *= exp->out_scale_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h new file mode 100644 index 00000000..d02a0562 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_EXP_H_ +#define MINDSPORE_NNACL_FP32_EXP_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/exp.h" +#include "nnacl_c/exp_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ExpFp32(const float *src, float *dst, int num); +int ExpFusionFp32(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_EXP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32_simd.h.in new file mode 100644 index 00000000..e884e43e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/exp_fp32_simd.h.in @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
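
ExpFusionFp32 above folds the op's attributes into two multipliers, so the intended math is dst = exp(src * in_scale) * out_scale, with each scaling pass skipped when its factor is 1 (my reading of the in_scale_/out_scale_ fields; the base/shift folding happens where ExpStruct is set up). A scalar reference, useful for checking the SIMD paths:

  #include <math.h>

  /* Scalar reference: dst[i] = exp(src[i] * in_scale) * out_scale. */
  static void ExpFusionRef(const float *src, float *dst, int num, float in_scale, float out_scale) {
    for (int i = 0; i < num; ++i) {
      dst[i] = expf(src[i] * in_scale) * out_scale;
    }
  }
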
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_EXP_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_EXP_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int64_t ExpFp32@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num) {
+  for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EXP_ST_F32(SIMD_LD_F32(src + index), dst + index);
+  }
+  return index;
+}
+
+static inline int64_t ExpFp32WithInScale@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num, float in_scale) {
+  SIMD_F32 scale_vec = SIMD_MOV_F32(in_scale);
+  for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EXP_ST_F32(SIMD_MUL_F32(SIMD_LD_F32(src + index), scale_vec), dst + index);
+  }
+  return index;
+}
+
+static inline int64_t ExpFp32WithOutScale@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num, float out_scale) {
+  SIMD_F32 scale_vec = SIMD_MOV_F32(out_scale);
+  for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EXP_ST_F32(SIMD_LD_F32(src + index), dst + index);
+    SIMD_ST_F32(dst + index, SIMD_MUL_F32(SIMD_LD_F32(dst + index), scale_vec));
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+};
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c
new file mode 100644
index 00000000..ed9e6cf2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.c
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/gatherNd_fp32.h"
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+
+int GatherNd(const void *input, void *output, const int32_t *in_offset, int area, int count, int data_type_len) {
+  int i = 0;
+  for (i = 0; i < count; i++) {
+    (void)memcpy((int8_t *)output + area * i * data_type_len, (int8_t *)input + in_offset[i] * data_type_len,
+                 (size_t)(area)*data_type_len);
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h
new file mode 100644
index 00000000..ffe80119
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gatherNd_fp32.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
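
GatherNd above is deliberately minimal: it copies `count` slices of `area` elements each, starting at element offsets the caller has already resolved. The offset precomputation lives in the calling kernel; under the assumption of row-major strides and an indices tensor of shape count x idx_rank, it amounts to (hypothetical helper, for illustration):

  #include <stdint.h>

  /* in_strides[d] = element stride of input dimension d (first idx_rank dims). */
  static void GatherNdComputeOffsets(const int32_t *indices, int count, int idx_rank,
                                     const int32_t *in_strides, int32_t *in_offset) {
    for (int i = 0; i < count; ++i) {
      int32_t off = 0;
      for (int d = 0; d < idx_rank; ++d) {
        off += indices[i * idx_rank + d] * in_strides[d];
      }
      in_offset[i] = off; /* element offset of the slice GatherNd will memcpy */
    }
  }
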
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GATHERND_FP32_H_
+#define NNACL_FP32_GATHERND_FP32_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int GatherNd(const void *input, void *output, const int32_t *in_offset, int area, int count, int data_type_len);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GATHERND_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.c
new file mode 100644
index 00000000..d6de5be2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.c
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/group_norm_fp32.h"
+#include <math.h>
+#include "nnacl_c/group_norm_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/group_norm_fp32_simd.h"
+
+static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group,
+                                 int cur_groups, const GroupNormParameter *param);
+
+int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance,
+                  const GroupNormParameter *param, int task_id, float *output) {
+  if (param->op_parameter_.thread_num_ == 0) {
+    return NNACL_ERR;
+  }
+  const int frame_elem_num = param->unit_ * param->channel_;
+  const int groups_per_thread = UP_DIV(param->num_groups_, param->op_parameter_.thread_num_);
+  const int completed_group = task_id * groups_per_thread;
+  const int cur_group = MSMIN(groups_per_thread, param->num_groups_ - completed_group);
+  const int num_of_ch_per_group = param->channel_ / param->num_groups_;
+  int cur_offset = completed_group * num_of_ch_per_group * param->unit_;
+
+  for (int b = 0; b < param->batch_; b++) {
+    const float *b_in = input + b * frame_elem_num;
+    float *b_out = output + b * frame_elem_num;
+    int b_offset = cur_offset;
+    GroupNormFp32MeanVar(b_in, mean, variance, completed_group, cur_group, param);
+    for (int g = 0; g < cur_group; g++) {
+      int grp_idx = g + completed_group;
+      int c_offset = grp_idx * num_of_ch_per_group;
+      float m = mean[grp_idx];
+      float v = variance[grp_idx];
+      float variance_sqrt = sqrtf(v + param->epsilon_);
+      if (variance_sqrt == 0) {
+        return NNACL_ERR;
+      }
+      for (int c = 0; c < num_of_ch_per_group; c++) {
+        const float *unit_input = b_in + b_offset;
+        float *unit_output = b_out + b_offset;
+        float s = scale[c_offset + c];
+        float o = offset[c_offset + c];
+        int u = 0;
+        SIMD_RUN_NO_SCALAR(GroupNormFp32, u,
unit_input, s, o, m, variance_sqrt, param->unit_, unit_output); + for (; u < param->unit_; u++) { + float norm_val = (unit_input[u] - m) / variance_sqrt; + unit_output[u] = norm_val * s + o; + } + b_offset += param->unit_; + } + } + } + return NNACL_OK; +} + +#define SimdReduceSum(block_size, block_num, in, i, sum) \ + do { \ + for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \ + MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, in + i); \ + sum += MS_GET_SUM_F32(block_size, input); \ + } \ + } while (0) + +#define SimdReduceVar(block_size, block_num, in, m, i, sum) \ + do { \ + MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m); \ + MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 0); \ + for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \ + MS_FLOAT_32xN(block_num) input = MS_SUB_F32(block_size, MS_LD_F32(block_size, in + i), mean); \ + tmp = MS_ADD_F32(block_size, tmp, MS_MUL_F32(block_size, input, input)); \ + } \ + sum += MS_GET_SUM_F32(block_size, tmp); \ + } while (0) + +static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group, + int cur_groups, const GroupNormParameter *param) { + const int num_of_ch_per_group = param->channel_ / param->num_groups_; + const float N = (float)(param->unit_ * num_of_ch_per_group); + + // calc mean + for (int g = 0; g < cur_groups; g++) { + int g_idx = g + completed_group; + float sum = 0; + for (int c = 0; c < num_of_ch_per_group; c++) { + const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_; + int i = 0; + SIMD_RUN_NO_SCALAR(GroupNormReduceSum, i, in, &sum, param->unit_); + for (; i < param->unit_; i++) { + sum += in[i]; + } + } + run_mean[g_idx] = sum / N; + } + + // calc variance + for (int g = 0; g < cur_groups; g++) { + int g_idx = g + completed_group; + float var = 0; + run_var[g_idx] = 0; + for (int c = 0; c < num_of_ch_per_group; c++) { + const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_; + int i = 0; + SIMD_RUN_NO_SCALAR(GroupNormReduceVar, i, in, run_mean[g_idx], &var, param->unit_); + for (; i < param->unit_; i++) { + var += (in[i] - run_mean[g_idx]) * (in[i] - run_mean[g_idx]); + } + } + run_var[g_idx] = var / N; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.h new file mode 100644 index 00000000..f5b595ae --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
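
Stripped of the SIMD plumbing, one group of GroupNormFp32 is the usual two-pass mean/variance reduction followed by an affine normalization. A scalar reference for a single group g (sketch, assuming batch 1 and the layout above: cpg = channel / num_groups channels of `unit` elements each):

  #include <math.h>

  static void GroupNormRefOneGroup(const float *x, float *y, const float *scale, const float *offset,
                                   int g, int cpg, int unit, float eps) {
    const float n = (float)(cpg * unit);
    float mean = 0.0f, var = 0.0f;
    for (int c = 0; c < cpg; ++c) { /* pass 1: mean over the whole group */
      for (int i = 0; i < unit; ++i) {
        mean += x[(g * cpg + c) * unit + i];
      }
    }
    mean /= n;
    for (int c = 0; c < cpg; ++c) { /* pass 2: variance around that mean */
      for (int i = 0; i < unit; ++i) {
        float d = x[(g * cpg + c) * unit + i] - mean;
        var += d * d;
      }
    }
    var /= n;
    const float deno = sqrtf(var + eps);
    for (int c = 0; c < cpg; ++c) { /* normalize + per-channel affine */
      for (int i = 0; i < unit; ++i) {
        int idx = (g * cpg + c) * unit + i;
        y[idx] = (x[idx] - mean) / deno * scale[g * cpg + c] + offset[g * cpg + c];
      }
    }
  }
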
+ */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/group_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance, + const GroupNormParameter *param, int task_id, float *output); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32_simd.h.in new file mode 100644 index 00000000..33bde9b7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/group_norm_fp32_simd.h.in @@ -0,0 +1,70 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t GroupNormFp32@SIMD_INSTRUCTION@(int64_t index, const float *unit_input, float scale, float offset, float mean, + float var_sqrt, int unit, float *unit_output) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 v_sqrt = SIMD_MOV_F32(var_sqrt); + SIMD_F32 scale_val = SIMD_MOV_F32(scale); + SIMD_F32 offset_val = SIMD_MOV_F32(offset); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(unit_input + index); + SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input, mean_val), v_sqrt); + SIMD_F32 output = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_val), offset_val); + SIMD_ST_F32(unit_output + index, output); + } + return index; +} + +static inline int64_t GroupNormReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *in, float *sum, int unit) { + if (unit - index >= 4 * BLOCK_NUM) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(in + index)); + } + *sum += SIMD_GET_SUM_F32(tmp); + } + return index; +} + +static inline int64_t GroupNormReduceVar@SIMD_INSTRUCTION@(int64_t index, const float *in, float mean, float *sum, int unit) { + if (unit - index >= 4 * BLOCK_NUM) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_SUB_F32(SIMD_LD_F32(in + index), mean_val); + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_F32(input, input)); + } + *sum += SIMD_GET_SUM_F32(tmp); + } + return index; 
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c
new file mode 100644
index 00000000..cf113430
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.c
@@ -0,0 +1,154 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/gru_fp32.h"
+#include <string.h>
+#include "nnacl_c/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+void GruMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) {
+  if (is_vec) {
+    MatVecMulFp32(a, b, c, bias, ActType_No, deep, col);
+  } else {
+    MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
+  }
+}
+
+void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hidden_buffer, const float *state_weight,
+                 const float *state_bias, float *hidden_state, float *buffer[4], const GruParameter *gru_param) {
+  float *packed_state = buffer[2];
+  float *state_gate = buffer[3];
+  bool is_vec = gru_param->batch_ == 1;
+
+  const float *state_update_weight = state_weight;
+  const float *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_;
+  const float *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2;
+  float *state_update_gate = state_gate;
+  float *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_;
+  float *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2;
+  const float *state_update_bias = state_bias;
+  const float *state_reset_bias = state_bias + gru_param->hidden_size_;
+  const float *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2;
+
+  // state * weight
+  if (is_vec) {
+    GruMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+    GruMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  } else {
+    PackLstmInput(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_);
+    GruMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+    GruMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_,
+              gru_param->hidden_size_, gru_param->hidden_size_, is_vec);
+  }
+  ElementAdd(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_);
+  ElementAdd(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate,
+             gru_param->batch_ * gru_param->hidden_size_);
+
+  // update reset_gate
+
Sigmoid(reset_gate, gru_param->batch_ * gru_param->hidden_size_, reset_gate); + // update update_gate + Sigmoid(update_gate, gru_param->batch_ * gru_param->hidden_size_, update_gate); + + ElementMul(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + if (is_vec) { + GruMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + PackLstmInput(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_); + GruMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + Tanh(hidden_buffer, gru_param->batch_ * gru_param->hidden_size_, hidden_buffer); + + ElementMul(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + const float one = 1.0f; + ElementOptSub(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, true); + + ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float)); +} + +void GruUnidirectional(float *output, const float *packed_input, const float *weight_g, const float *weight_r, + const float *input_bias, const float *state_bias, float *hidden_state, float *buffer[4], + const GruParameter *gru_param, bool is_backward) { + float *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float *bias_loop = input_bias + gru_param->input_col_align_ * i; + float *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float *update_gate = gate; + float *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? 
gru_param->seq_len_ - t - 1 : t; + float *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnit(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, buffer, + gru_param); + } +} + +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias, + const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len, + const GruParameter *gru_param) { + // forward + float *packed_input = buffer[0]; + PackLstmInput(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_); + GruUnidirectional(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, gru_param, + false); + + // zero out extra fw outputs + for (int t = check_seq_len; t < gru_param->seq_len_; t++) { + float *output_ptr = output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + + // backward + if (gru_param->bidirectional_) { + const float *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_; + const float *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; + float *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; + float *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; + + GruUnidirectional(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + + // zero out extra bw outputs + for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { + float *output_ptr = backward_output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h new file mode 100644 index 00000000..cba05c1b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/gru_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
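
Per element, GruStepUnit implements the standard GRU cell; note this code's convention weights the previous state by the update gate (h' = z*h + (1-z)*n), where some formulations swap z and 1-z. A scalar transcription with the matmul projections collapsed to scalar multiplies and biases folded in (illustrative only):

  #include <math.h>

  static float Sigmoidf(float x) { return 1.0f / (1.0f + expf(-x)); }

  /* a_* are input-side pre-activations, u_* state-side weights, h the previous state. */
  static float GruStepRef(float a_r, float u_r, float a_z, float u_z, float a_n, float u_n, float h) {
    float r = Sigmoidf(a_r + u_r * h);    /* reset gate  */
    float z = Sigmoidf(a_z + u_z * h);    /* update gate */
    float n = tanhf(a_n + u_n * (r * h)); /* candidate, reset applied to the state */
    return z * h + (1.0f - z) * n;        /* mix previous state with candidate */
  }
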
+ */
+#ifndef MINDSPORE_NNACL_FP32_GRU_FP32_H_
+#define MINDSPORE_NNACL_FP32_GRU_FP32_H_
+#include "nnacl_c/gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias,
+         const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len,
+         const GruParameter *gru_parm);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_GRU_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.c
new file mode 100644
index 00000000..455eb071
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.c
@@ -0,0 +1,374 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/instance_norm_fp32.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data,
+                 const InstanceNormParameter *param, size_t task_id) {
+  NNACL_CHECK_NULL_RETURN_ERR(src_data);
+  NNACL_CHECK_NULL_RETURN_ERR(dst_data);
+  NNACL_CHECK_NULL_RETURN_ERR(gamma_data);
+  NNACL_CHECK_NULL_RETURN_ERR(beta_data);
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_);
+  int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_);
+  int channel_begin = (int)(task_id)*channel_step;
+  int channel_end = MSMIN(channel_begin + channel_step, param->channel_);
+
+  for (int b = 0; b < param->batch_; b++) {
+    const float *src_b = src_data + b * param->channel_ * param->inner_size_;
+    float *dst_b = dst_data + b * param->channel_ * param->inner_size_;
+    for (int c = channel_begin; c < channel_end; c++) {
+      const float *src = src_b + c * param->inner_size_;
+      float *dst = dst_b + c * param->inner_size_;
+      double mean = 0.0f;
+      double squ_m = 0.0f;
+
+      int index = 0;
+#if defined(ENABLE_AVX)
+      for (; index <= param->inner_size_ - C8NUM; index += C8NUM) {
+        __m256 srcv = _mm256_loadu_ps(src + index);
+        __m256 squarev = _mm256_mul_ps(srcv, srcv);
+        __m128 src128 = _mm_add_ps(_mm256_extractf128_ps(srcv, 0), _mm256_extractf128_ps(srcv, 1));
+        __m128 square128 = _mm_add_ps(_mm256_extractf128_ps(squarev, 0), _mm256_extractf128_ps(squarev, 1));
+        for (int i = 0; i < C4NUM; ++i) {
+          mean += MS_F32X4_GETI(src128, i);
+          squ_m += MS_F32X4_GETI(square128, i);
+        }
+      }
+#endif
+
+#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
+      for (; index <= param->inner_size_ - C4NUM; index += C4NUM) {
+        MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index);
+        MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv);
+#ifdef ENABLE_ARM64
+        mean += vaddvq_f32(srcv);
+        squ_m += vaddvq_f32(squarev);
+#elif defined(ENABLE_SSE)
+        for (int i = 0; i < C4NUM; ++i) {
+          mean += MS_F32X4_GETI(srcv, i);
+          squ_m +=
MS_F32X4_GETI(squarev, i); + } +#else + float32x2_t src_add2 = vadd_f32(vget_low_f32(srcv), vget_high_f32(srcv)); + float32x2_t src_add4 = vpadd_f32(src_add2, src_add2); + mean += vget_lane_f32(src_add4, 0); + float32x2_t square_add2 = vadd_f32(vget_low_f32(squarev), vget_high_f32(squarev)); + float32x2_t square_add4 = vpadd_f32(square_add2, square_add2); + squ_m += vget_lane_f32(square_add4, 0); +#endif + } +#endif + for (; index < param->inner_size_; index++) { + mean += src[index]; + squ_m += src[index] * src[index]; + } + + mean /= (float)param->inner_size_; + squ_m /= (float)param->inner_size_; + const double deno = gamma_data[c] / sqrt(squ_m - mean * mean + param->epsilon_); + + index = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 meanv8 = MS_MOV256_F32(mean); + MS_FLOAT32X8 denov8 = MS_MOV256_F32(deno); + for (; index <= param->inner_size_ - C8NUM; index += C8NUM) { + MS_FLOAT32X8 srcv8 = MS_LD256_F32(src + index); + MS_FLOAT32X8 dstv8 = + MS_ADD256_F32(MS_MUL256_F32(MS_SUB256_F32(srcv8, meanv8), denov8), MS_MOV256_F32(*(beta_data + c))); + MS_ST256_F32(dst + index, dstv8); + } +#endif + +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 meanv4 = MS_MOVQ_F32(mean); + MS_FLOAT32X4 denov4 = MS_MOVQ_F32(deno); + for (; index <= param->inner_size_ - C4NUM; index += C4NUM) { + MS_FLOAT32X4 srcv4 = MS_LDQ_F32(src + index); + MS_FLOAT32X4 dstv4 = + MS_ADDQ_F32(MS_MULQ_F32(MS_SUBQ_F32(srcv4, meanv4), denov4), MS_MOVQ_F32(*(beta_data + c))); + MS_STQ_F32(dst + index, dstv4); + } +#endif + for (; index < param->inner_size_; index++) { + dst[index] = (src[index] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} + +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) +void InstanceNormC4HW4ArmSse(const float *src_b, float *dst_b, const float *gamma_data, const float *beta_data, + int32_t *c_src, const InstanceNormParameter *param, int channel, int channel_end, + int hw_plane, MS_FLOAT32X4 hw_planev) { + int c = *c_src; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; + const float *src2 = src_b + (c + C8NUM) * hw_plane, *src3 = src_b + (c + C12NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 mean2 = MS_MOVQ_F32(0.0f), mean3 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m2 = MS_MOVQ_F32(0.0f), squ_m3 = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); + MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); + MS_FLOAT32X4 squarev2 = MS_MULQ_F32(srcv2, srcv2), squarev3 = MS_MULQ_F32(srcv3, srcv3); + MS_ADDQ_F32_VEC(mean, mean1, mean2, mean3, srcv, srcv1, srcv2, srcv3); + MS_ADDQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, squarev, squarev1, squarev2, squarev3); + } + MS_DIVQ_F32_VEC(mean, mean1, mean2, mean3, hw_planev); + MS_DIVQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, hw_planev); + + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno2 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m2, MS_MULQ_F32(mean2, mean2)), MS_MOVQ_F32(param->epsilon_)); + 
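+      /* deno..deno3 hold Var(x) + epsilon per lane via the one-pass identity
+         Var(x) = E[x^2] - E[x]^2 (the squ_m* accumulators are E[x^2], the mean*
+         accumulators E[x]); the reciprocal square roots below turn them into
+         1/sqrt(var + eps), which is folded into gamma so the normalization loop
+         is multiply-add only. Note the one-pass form can lose precision when
+         mean^2 approaches E[x^2]. */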
MS_FLOAT32X4 deno3 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m3, MS_MULQ_F32(mean3, mean3)), MS_MOVQ_F32(param->epsilon_)); + + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); + deno2 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno2)); + deno3 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno3)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] + MS_FLOAT32X4 gammav2 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C8NUM), deno2); // deno * gamma_data[c] + MS_FLOAT32X4 gammav3 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C12NUM), deno3); // deno * gamma_data[c] + MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); + MS_FLOAT32X4 betav2 = MS_LDQ_F32(beta_data + c + C8NUM), betav3 = MS_LDQ_F32(beta_data + c + C12NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); + MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); + MS_FLOAT32X4 outv2 = MS_SUBQ_F32(srcv2, mean2), outv3 = MS_SUBQ_F32(srcv3, mean3); + + outv = MS_MULQ_F32(outv, gammav), outv1 = MS_MULQ_F32(outv1, gammav1); + outv2 = MS_MULQ_F32(outv2, gammav2), outv3 = MS_MULQ_F32(outv3, gammav3); + MS_ADDQ_F32_VEC(outv, outv1, outv2, outv3, betav, betav1, betav2, betav3); + + MS_STQ_F32(dst + index * channel, outv), MS_STQ_F32(dst + index * channel + C4NUM, outv1); + MS_STQ_F32(dst + index * channel + C8NUM, outv2), MS_STQ_F32(dst + index * channel + C12NUM, outv3); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); + mean = MS_ADDQ_F32(mean, srcv), mean1 = MS_ADDQ_F32(mean1, srcv1); + squ_m = MS_ADDQ_F32(squ_m, squarev), squ_m1 = MS_ADDQ_F32(squ_m1, squarev1); + } + + MS_DIVQ_F32_VEC(mean, mean1, squ_m, squ_m1, hw_planev); + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] + MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); + outv = MS_MULQ_F32(outv, gammav), outv1 = MS_MULQ_F32(outv1, gammav1); + outv = MS_ADDQ_F32(outv, 
betav), outv1 = MS_ADDQ_F32(outv1, betav1); + MS_STQ_F32(dst + index * channel, outv); + MS_STQ_F32(dst + index * channel + C4NUM, outv1); + } + } + for (; c <= channel_end - C4NUM; c += C4NUM) { + const float *src = src_b + c * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), squ_m = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), squarev = MS_MULQ_F32(srcv, srcv); + mean = MS_ADDQ_F32(mean, srcv), squ_m = MS_ADDQ_F32(squ_m, squarev); + } + mean = MS_DIVQ_F32(mean, hw_planev), squ_m = MS_DIVQ_F32(squ_m, hw_planev); + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno), betav = MS_LDQ_F32(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), outv = MS_SUBQ_F32(srcv, mean); + MS_STQ_F32(dst + index * channel, MS_ADDQ_F32(MS_MULQ_F32(outv, gammav), betav)); + } + } + *c_src = c; +} +#endif + +int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_NULL_RETURN_ERR(gamma_data); + NNACL_CHECK_NULL_RETURN_ERR(beta_data); + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + int channel_step = UP_DIV(UP_DIV(channel, C4NUM), param->op_parameter_.thread_num_) * C4NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_FLOAT32X4 hw_planev = MS_MOVQ_F32((float)(hw_plane)); +#endif + for (int b = 0; b < param->batch_; b++) { + const float *src_b = src_data + b * channel * hw_plane; + float *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + InstanceNormC4HW4ArmSse(src_b, dst_b, gamma_data, beta_data, &c, param, channel, channel_end, hw_plane, hw_planev); +#endif + for (; c < channel_end; ++c) { + int c4_down_loop = c / C4NUM * C4NUM; + int c4_mod = c % C4NUM; + int c_res = MSMIN(channel_end - c4_down_loop, C4NUM); + const float *src = src_b + c4_down_loop * hw_plane + c4_mod; + float *dst = dst_b + c; + float mean = 0.0f; + float squ_m = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float tmp = src[index * c_res]; + mean += tmp; + squ_m += tmp * tmp; + } + mean /= (float)hw_plane; + squ_m /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} + +#ifdef ENABLE_AVX +int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_NULL_RETURN_ERR(gamma_data); + NNACL_CHECK_NULL_RETURN_ERR(beta_data); + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = 
param->channel_, hw_plane = param->inner_size_; + int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); + MS_FLOAT32X8 hw_planev = MS_MOV256_F32((float)(hw_plane)); + for (int b = 0; b < param->batch_; b++) { + const float *src_b = src_data + b * channel * hw_plane; + float *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float *src = src_b + c * hw_plane; + const float *src1 = src_b + (c + C8NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), mean1 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 squ_m = MS_MOV256_F32(0.0f), squ_m1 = MS_MOV256_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM); + MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv), squarev1 = MS_MUL256_F32(srcv1, srcv1); + mean = MS_ADD256_F32(mean, srcv); + mean1 = MS_ADD256_F32(mean1, srcv1); + squ_m = MS_ADD256_F32(squ_m, squarev); + squ_m1 = MS_ADD256_F32(squ_m1, squarev1); + } + mean = MS_DIV256_F32(mean, hw_planev); + mean1 = MS_DIV256_F32(mean1, hw_planev); + squ_m = MS_DIV256_F32(squ_m, hw_planev); + squ_m1 = MS_DIV256_F32(squ_m1, hw_planev); + MS_FLOAT32X8 deno = + MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), MS_MOV256_F32(param->epsilon_)); + MS_FLOAT32X8 deno1 = + MS_ADD256_F32(MS_SUB256_F32(squ_m1, MS_MUL256_F32(mean1, mean1)), MS_MOV256_F32(param->epsilon_)); + deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno)); + deno1 = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno1)); + + MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X8 gammav1 = MS_MUL256_F32(MS_LD256_F32(gamma_data + c + C8NUM), deno1); // deno1 * gamma_data[c] + MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c), betav1 = MS_LD256_F32(beta_data + c + C8NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM); + MS_FLOAT32X8 outv = MS_SUB256_F32(srcv, mean), outv1 = MS_SUB256_F32(srcv1, mean1); + outv = MS_MUL256_F32(outv, gammav); + outv1 = MS_MUL256_F32(outv1, gammav1); + outv = MS_ADD256_F32(outv, betav); + outv1 = MS_ADD256_F32(outv1, betav1); + MS_ST256_F32(dst + index * channel, outv); + MS_ST256_F32(dst + index * channel + C8NUM, outv1); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float *src = src_b + c * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), squ_m = MS_MOV256_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM); + MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv); + mean = MS_ADD256_F32(mean, srcv); + squ_m = MS_ADD256_F32(squ_m, squarev); + } + mean = MS_DIV256_F32(mean, hw_planev); + squ_m = MS_DIV256_F32(squ_m, hw_planev); + MS_FLOAT32X8 deno = MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), + MS_MOV256_F32(param->epsilon_)); + deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno)); + + MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + 
index * C8NUM), outv = MS_SUB256_F32(srcv, mean); + outv = MS_MUL256_F32(outv, gammav); + outv = MS_ADD256_F32(outv, betav); + MS_ST256_F32(dst + index * channel, outv); + } + } + for (; c < channel_end; ++c) { + int c8_down_loop = c / C8NUM * C8NUM, c8_mod = c % C8NUM; + int c_res = MSMIN(channel_end - c8_down_loop, C8NUM); + const float *src = src_b + c8_down_loop * hw_plane + c8_mod; + float *dst = dst_b + c; + float mean = 0.0f, squ_m = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float tmp = src[index * c_res]; + mean += tmp; + squ_m += tmp * tmp; + } + mean /= (float)hw_plane; + squ_m /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.h new file mode 100644 index 00000000..6d908e59 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/instance_norm_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ +#define MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/instance_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define MS_ADDQ_F32_VEC(in1, in2, in3, in4, v1, v2, v3, v4) \ + in1 = MS_ADDQ_F32(in1, v1); \ + in2 = MS_ADDQ_F32(in2, v2); \ + in3 = MS_ADDQ_F32(in3, v3); \ + in4 = MS_ADDQ_F32(in4, v4); + +#define MS_DIVQ_F32_VEC(in1, in2, in3, in4, v) \ + in1 = MS_DIVQ_F32(in1, v); \ + in2 = MS_DIVQ_F32(in2, v); \ + in3 = MS_DIVQ_F32(in3, v); \ + in4 = MS_DIVQ_F32(in4, v); + +int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +#ifdef ENABLE_AVX +int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +#endif +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.c new file mode 100644 index 00000000..04662b54 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.c @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
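All of the InstanceNorm variants above make one pass per channel that accumulates the sum and the sum of squares together, then obtain the variance through the identity Var(x) = E[x^2] - (E[x])^2 and fold gamma into the reciprocal denominator. A minimal scalar sketch of one channel, assuming a contiguous buffer of length n > 0 (the helper name is illustrative, not part of the patch):

#include <math.h>
/* Normalize one instance-norm channel; mirrors the scalar tail of InstanceNorm. */
static void instance_norm_channel(const float *src, float *dst, int n,
                                  float gamma, float beta, float epsilon) {
  float mean = 0.0f, squ_m = 0.0f;
  for (int i = 0; i < n; ++i) {
    mean += src[i];
    squ_m += src[i] * src[i]; /* E[x^2] accumulated in the same pass */
  }
  mean /= (float)n;
  squ_m /= (float)n;
  const float deno = gamma / sqrtf(squ_m - mean * mean + epsilon);
  for (int i = 0; i < n; ++i) {
    dst[i] = (src[i] - mean) * deno + beta; /* gamma folded into deno */
  }
}

The SSE/NEON and AVX paths are this same computation unrolled over 4, 8, or 16 channels at a time.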
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/invert_permutation_fp32.h" +#include "nnacl_c/errorcode.h" + +int InvertPermutation(const int32_t *input, int32_t *output, size_t num) { + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + for (size_t i = 0; i < num; i++) { + size_t index = (size_t)input[i]; + if (index >= num) { + return NNACL_ERR; + } + output[index] = i; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.h new file mode 100644 index 00000000..d9cf1917 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/invert_permutation_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_ +#define MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_ + +#include <stdint.h> +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int InvertPermutation(const int32_t *input, int32_t *output, size_t num); +#ifdef __cplusplus +} +#endif + +#endif  // MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c new file mode 100644 index 00000000..56087123 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.c @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
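InvertPermutation above writes output[input[i]] = i and rejects any out-of-range index with NNACL_ERR before storing, so a malformed permutation cannot trigger an out-of-bounds write. A small usage sketch (the values are illustrative, and NNACL_OK is assumed to equal 0):

#include <stdint.h>
#include <stddef.h>
#include <stdio.h>

int InvertPermutation(const int32_t *input, int32_t *output, size_t num); /* from the header above */

int main(void) {
  int32_t perm[4] = {2, 0, 3, 1};
  int32_t inv[4] = {0, 0, 0, 0};
  if (InvertPermutation(perm, inv, 4) == 0) { /* 0 == NNACL_OK */
    for (int i = 0; i < 4; ++i) {
      printf("%d ", inv[i]); /* prints: 1 3 0 2 */
    }
  }
  return 0;
}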
+ */ + +#include "nnacl_c/fp32/l2_norm_fp32.h" +#include +#include "nnacl_c/errorcode.h" + +int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end) { + *sum = 0.0f; + int i; + for (i = begin; i < end; ++i) { + *sum += input_ptr[i] * input_ptr[i]; + } + return NNACL_OK; +} + +int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum, + const int begin, const int end) { + bool is_relu = param->act_type_ == ActType_Relu; + bool is_relu6 = param->act_type_ == ActType_Relu6; + int i; + if (sqrt_sum == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (i = begin; i < end; i++) { + float tmp = input_ptr[i] / sqrt_sum; + if (is_relu) { + output_ptr[i] = MSMAX(0, tmp); + } else if (is_relu6) { + output_ptr[i] = MSMIN(6, MSMAX(0, tmp)); + } else { + output_ptr[i] = tmp; + } + } + return NNACL_OK; +} + +int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin, + const int end) { + bool is_relu = param->act_type_ == ActType_Relu; + bool is_relu6 = param->act_type_ == ActType_Relu6; + + const int c = param->shape_[param->shape_num_ - 1]; + int i = 0; + for (i = begin; i < end; ++i) { + float square_sum = 0.0f; + int j = 0; + for (j = 0; j < c; ++j) { + const float val = input_ptr[i * c + j]; + square_sum += val * val; + } + float sqrt_sum = sqrtf(square_sum > param->epsilon_ ? square_sum : param->epsilon_); + for (j = 0; j < c; ++j) { + float tmp = input_ptr[i * c + j] / sqrt_sum; + if (is_relu) { + output_ptr[i * c + j] = MSMAX(0, tmp); + } else if (is_relu6) { + output_ptr[i * c + j] = MSMIN(6, MSMAX(0, tmp)); + } else { + output_ptr[i * c + j] = tmp; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h new file mode 100644 index 00000000..2af8506d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/l2_norm_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ +#define MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ + +#include "nnacl_c/l2_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end); +int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum, + const int begin, const int end); +int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin, + const int end); +#ifdef __cplusplus +} +#endif + +#endif  // MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c new file mode 100644 index 00000000..e311fec5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/layer_norm_fp32.h" +#include <math.h> +#include "nnacl_c/errorcode.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/layer_norm_fp32_simd.h" + +int LayerNormMeanAndSquare(const float *src, int num, float *mean, float *variance) { + if (num <= 0) { + return NNACL_ERR; + } + int index = 0; + float square_mean = 0.f; + + SIMD_RUN_NO_SCALAR(LayerNormMeanAndSquare, index, src, num, mean, &square_mean); + + for (; index < num; index++) { + *mean += src[index]; + square_mean += src[index] * src[index]; + } + *mean /= (float)num; + square_mean /= (float)num; + *variance = square_mean - (*mean) * (*mean); + return NNACL_OK; +} + +void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data, const float *beta_data, int num, + const float mean, const float deno) { + int index = 0; + + SIMD_RUN_NO_SCALAR(LayerNormGammaAndBeta, index, dst, src, gamma_data, beta_data, num, mean, deno); + + for (; index < num; index++) { + dst[index] = (src[index] - mean) * (deno); + dst[index] = dst[index] * gamma_data[index] + beta_data[index]; + } +} + +int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean, + float *out_variance, const LayerNormComputeParam *param, int task_id, int thread_num) { + if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { + return NNACL_NULL_PTR; + } + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_); + int step = UP_DIV(param->norm_outer_size_, thread_num); + int thread_end = MSMIN(((int)task_id + 1) * step, param->norm_outer_size_); + for (int i = task_id * step; i < thread_end; i++) { + const float *src_norm = src_data + i * param->norm_inner_size_; + float *dst_norm = dst_data + i * param->norm_inner_size_; + float cur_mean = 0.0f; + float cur_variance = 0.0f; + int ret = LayerNormMeanAndSquare(src_norm, 
param->norm_inner_size_, &cur_mean, &cur_variance); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + if (out_mean != NULL) { + out_mean[i] = cur_mean; + } + if (out_variance != NULL) { + out_variance[i] = cur_variance; + } + const float deno = 1 / sqrtf(cur_variance + param->epsilon_); + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const float *src_param = src_norm + x * param->params_inner_size_; + float *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBeta(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean, deno); + } + } else { + int x = i / param->params_outer_size_; + const float *gamma = gamma_data + x * param->norm_inner_size_; + const float *beta = beta_data + x * param->norm_inner_size_; + LayerNormGammaAndBeta(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h new file mode 100644 index 00000000..cfb0aaf2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_LAYER_NORM_FP32_H_ +#define NNACL_FP32_LAYER_NORM_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/layer_norm_parameter.h" +#include "nnacl_c/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean, + float *out_variance, const LayerNormComputeParam *param, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_LAYER_NORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32_simd.h.in new file mode 100644 index 00000000..e05c5666 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/layer_norm_fp32_simd.h.in @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
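LayerNorm above normalizes each norm_inner_size_ vector with its own statistics, then applies gamma and beta either tiled across the inner axis (when norm_outer_size_ <= params_outer_size_) or taken from a per-outer-index slice. One normalized vector in scalar form, equivalent to LayerNormMeanAndSquare followed by LayerNormGammaAndBeta (illustrative helper, assuming num > 0):

#include <math.h>
static void layer_norm_vec(const float *src, float *dst, const float *gamma,
                           const float *beta, int num, float epsilon) {
  float mean = 0.0f, square_mean = 0.0f;
  for (int i = 0; i < num; ++i) {
    mean += src[i];
    square_mean += src[i] * src[i];
  }
  mean /= (float)num;
  square_mean /= (float)num;
  const float deno = 1.0f / sqrtf(square_mean - mean * mean + epsilon); /* rsqrt of variance */
  for (int i = 0; i < num; ++i) {
    dst[i] = (src[i] - mean) * deno * gamma[i] + beta[i];
  }
}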
+ */ +#ifndef MINDSPORE_NNACL_FP32_LAYER_NORM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_LAYER_NORM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int LayerNormMeanAndSquare@SIMD_INSTRUCTION@(int index, const float *src, int num, float *mean, float *square_mean) { + if (num >= 4 * BLOCK_NUM) { + SIMD_F32 sum_val = SIMD_SET0_F32; + SIMD_F32 square_sum_val = SIMD_SET0_F32; + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(src + index); + SIMD_F32 square_value = SIMD_MUL_F32(value, value); + sum_val = SIMD_ADD_F32(sum_val, value); + square_sum_val = SIMD_ADD_F32(square_sum_val, square_value); + } + *mean += SIMD_GET_SUM_F32(sum_val); + *square_mean += SIMD_GET_SUM_F32(square_sum_val); + } + return index; +} + +static inline int LayerNormGammaAndBeta@SIMD_INSTRUCTION@(int index, float *dst, const float *src, const float *gamma_data, + const float *beta_data, int num, const float mean, const float deno) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 deno_val = SIMD_MOV_F32(deno); + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(src + index); + SIMD_F32 out_value = SIMD_SUB_F32(value, mean_val); + out_value = SIMD_MUL_F32(out_value, deno_val); + out_value = SIMD_FMADD_F32(out_value, SIMD_LD_F32(gamma_data + index), SIMD_LD_F32(beta_data + index)); + SIMD_ST_F32(dst + index, out_value); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.c new file mode 100644 index 00000000..25cdba7e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.c @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
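The .h.in template above is instantiated once per instruction set, and SIMD_RUN_NO_SCALAR picks the widest available variant at runtime. Each generated function consumes the data in BLOCK_NUM-wide chunks starting at index and returns the first index it did not process, which is exactly where the caller's scalar tail resumes. A plain-C sketch of that contract, with an ordinary loop standing in for the generated SIMD body (names are illustrative):

/* "Vector" stage: processes 4-wide chunks, returns the resume index. */
static int sum_block4(int index, const float *src, int num, float *sum) {
  for (; index <= num - 4; index += 4) { /* 4 stands in for BLOCK_NUM */
    *sum += src[index] + src[index + 1] + src[index + 2] + src[index + 3];
  }
  return index; /* first unprocessed element */
}

static float sum_all(const float *src, int num) {
  float sum = 0.0f;
  int index = sum_block4(0, src, num, &sum); /* chunked stage */
  for (; index < num; ++index) {
    sum += src[index]; /* scalar tail picks up where the chunks stopped */
  }
  return sum;
}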
+ */ + +#include "nnacl_c/fp32/local_response_norm_fp32.h" +#include +#include "nnacl_c/errorcode.h" + +int LocalResponseNorm(const float *input_ptr, int out_size, int channel, float *output_ptr, + const LocalResponseNormParameter *param) { + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + NNACL_CHECK_NULL_RETURN_ERR(param); + int64_t depth_radius = param->depth_radius_; + float bias = param->bias_; + float alpha = param->alpha_; + float beta = param->beta_; + + for (int i = 0; i < out_size; i++) { + const float *in_data = input_ptr + i * channel; + float *out_data = output_ptr + i * channel; + // border_left + for (int j = 0; j < MSMIN(depth_radius, channel); j++) { + int left = MSMAX(0, j - depth_radius); + int right = MSMIN(channel - 1, j + depth_radius); + float sum = 0.0f; + for (int k = left; k <= right; k++) { + const float in_val = in_data[k]; + sum += in_val * in_val; + } + out_data[j] = in_data[j] * (float)(powf(sum * alpha + bias, -beta)); + } + // center + if (2 * depth_radius + 1 < channel) { + float tmp_sum = 0.0f; + for (int j = 0; j < depth_radius * 2 + 1; ++j) { + tmp_sum += in_data[j] * in_data[j]; + } + out_data[depth_radius] = in_data[depth_radius] * (powf(tmp_sum * alpha + bias, -beta)); + for (int j = depth_radius + 1; j < channel - depth_radius; ++j) { + tmp_sum -= in_data[j - depth_radius - 1] * in_data[j - depth_radius - 1]; + tmp_sum += in_data[j + depth_radius] * in_data[j + depth_radius]; + out_data[j] = in_data[j] * (float)(powf(tmp_sum * alpha + bias, -beta)); + } + } + // border_right + for (int j = MSMAX(0, channel - depth_radius); j < channel; j++) { + int left = MSMAX(0, j - depth_radius); + int right = MSMIN(channel - 1, j + depth_radius); + float sum = 0.0f; + for (int k = left; k <= right; k++) { + const float in_val = in_data[k]; + sum += in_val * in_val; + } + out_data[j] = in_data[j] * (float)(powf(sum * alpha + bias, -beta)); + } + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.h new file mode 100644 index 00000000..73448cf1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/local_response_norm_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_ +#define NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/local_response_norm_parameter.h" + +int LocalResponseNorm(const float *input_ptr, int out_size, int channel, float *output_ptr, + const LocalResponseNormParameter *param); + +#endif  // NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.c new file mode 100644 index 00000000..4f79799a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.c @@ -0,0 +1,85 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/log_softmax_fp32.h" +#include <math.h> +#include "nnacl_c/fp32/softmax_fp32.h" +#include "nnacl_c/fp32/exp_fp32.h" + +void LogSoftmaxLastAxis(const float *src, float *dst, float *exp_data, int batch, int channel) { + SoftmaxNorm(src, dst, batch, channel); + ExpFp32(dst, exp_data, batch * channel); + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + float sum = 0; + int j = 0; +#ifdef ENABLE_NEON + float32x4_t sum4 = vdupq_n_f32(0); + int count = (channel / C4NUM) * C4NUM; + for (; j < count; j += C4NUM) { + sum4 = vaddq_f32(sum4, vld1q_f32(exp_data + cur_batch_offset + j)); + } + sum = sum4[0] + sum4[1] + sum4[2] + sum4[3]; +#endif + for (; j < channel; j++) { + sum += exp_data[cur_batch_offset + j]; + } + for (int k = 0; k < channel; k++) { + dst[cur_batch_offset + k] = dst[cur_batch_offset + k] - logf(sum); + } + } +} + +// output = (input - reduce_max(input, axis)) - log(reduce_sum(exp(input - reduce_max(input, axis)), axis)) +void LogSoftmax(const float *input_ptr, float *output_ptr, float *sum_data, int32_t *input_shape, int n_dim, int axis) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? 
max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = input_ptr[axis_offset] - max_data; + sum_data[k + sum_outter_offset] += expf(output_ptr[axis_offset]); + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] - logf(sum_data[k + sum_outter_offset]); + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.h new file mode 100644 index 00000000..9715999d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/log_softmax_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_LOG_SOFTMAX_FP32_H_ +#define NNACL_FP32_LOG_SOFTMAX_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/softmax_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void LogSoftmax(const float *input_ptr, float *output_ptr, float *sum_data, int32_t *input_shape, int n_dim, int axis); +void LogSoftmaxLastAxis(const float *src, float *dst, float *exp_data, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_LOG_SOFTMAX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.c new file mode 100644 index 00000000..932c6600 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.c @@ -0,0 +1,328 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
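LogSoftmax above exploits the shift invariance of softmax: (x - max) - log(sum(exp(x - max))) equals log(softmax(x)) exactly, while the max subtraction keeps expf from overflowing. A scalar sketch for one vector (illustrative helper, assuming n > 0):

#include <math.h>
static void log_softmax_vec(const float *x, float *y, int n) {
  float max_v = x[0];
  for (int i = 1; i < n; ++i) {
    max_v = x[i] > max_v ? x[i] : max_v;
  }
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    y[i] = x[i] - max_v; /* shift: softmax(x) == softmax(x - max) */
    sum += expf(y[i]);
  }
  const float log_sum = logf(sum);
  for (int i = 0; i < n; ++i) {
    y[i] -= log_sum;
  }
}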
+ */ + +#include "nnacl_c/fp32/lstm_fp32.h" +#include +#include +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +static void PackLstmMatrix(const float *src_batch, float *dst_batch, int col, int deep) { +#ifdef ENABLE_AVX + RowMajor2Col16Major(src_batch, dst_batch, col, deep); +#elif defined(ENABLE_ARM32) + RowMajor2Col4Major(src_batch, dst_batch, col, deep); +#else + RowMajor2Col8Major(src_batch, dst_batch, col, deep); +#endif +} + +static void PackLstmWeightBatch(float *dst, const float *src, int batch, int deep, int col, int col_align, + const int32_t *order) { + for (int i = 0; i < batch; i++) { + const float *src_batch = src + i * col * deep; + float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align * deep; + PackLstmMatrix(src_batch, dst_batch, col, deep); + } +} + +void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order) { + PackLstmWeightBatch(dst, src, batch, deep, col, col_align, order); +} + +void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, + bool is_bidirectional, int stride, const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order); + src += stride; + dst += unidirectional_batch * col_align * deep; + if (is_bidirectional) { + PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order); + } +} + +void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align; + (void)memcpy(dst_batch, src_batch, col * sizeof(float)); + } + if (is_bidirectional) { + const float *backward_src = src + batch * col; + float *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float *backward_dst_batch = backward_dst + ((order == NULL) ? i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float)); + } + } +} + +void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + int b_stride, const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align; + (void)memcpy(dst_batch, src_batch, col * sizeof(float)); + } + if (is_bidirectional) { + const float *backward_src = src + b_stride; + float *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float *backward_dst_batch = backward_dst + ((order == NULL) ? 
i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float)); + } + } +} + +void PackLstmInput(const float *src, float *dst, int row, int deep) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src, dst, row, deep); +#elif defined(ENABLE_SSE) + RowMajor2Col4Major(src, dst, row, deep); +#else + RowMajor2Col12Major(src, dst, row, deep); +#endif +} + +void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align, + bool is_vec, float *packed_ptr) { + if (is_vec) { +#ifdef ENABLE_AVX + bool need_packed = col % C8NUM; + if (!need_packed) { + packed_ptr = c; + } + MatVecMulAvxFp32(a, b, packed_ptr, bias, ActType_No, deep, col, col_align); + if (need_packed) { + PackNHWCXToNHWCFp32(packed_ptr, c, 1, row, col, C8NUM); + } +#else + MatVecMulFp32(a, b, c, bias, ActType_No, deep, col); +#endif + } else { + MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} + +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) { + int index = 0; +#ifdef ENABLE_ARM + for (; index <= element_size - 4; index += 4) { + float32x4_t in_0 = vld1q_f32(input0 + index); + float32x4_t in_1 = vld1q_f32(input1 + index); + float32x4_t out = vld1q_f32(output + index); + out = vmlaq_f32(out, in_1, in_0); + vst1q_f32(output + index, out); + } +#endif + for (; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 4; index += C4NUM) { + float32x4_t vin0 = vld1q_f32(input0 + index); + float32x4_t vout = vld1q_f32(output + index); + vout = vmlaq_n_f32(vout, vin0, input1); + vst1q_f32(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] += input0[index] * input1; + } + return NNACL_OK; +} + +void UpdateState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate, + float *state_buffer, int batch, int hidden_size, const float zoneout) { + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // zoneout * old_cell_state + (void)memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float)); + ElementOptMul(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); + } + + ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); + + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // (1 - zoneout) * new_cell_state + ElementOptMulAcc(cell_state, 1 - zoneout, state_buffer, batch * hidden_size); + } +} + +void UpdateOutput(float *hidden_state, float *output, const float *cell_state, const float *output_gate, + const float *weight_project, float *buffer[C8NUM], const LstmParameter *lstm_param) { + int batch = lstm_param->batch_; + int hidden_size = lstm_param->hidden_size_; + int output_size = lstm_param->output_size_; + float *state_buffer = buffer[C4NUM]; + float *hidden_buffer = weight_project ? 
buffer[C2NUM] : hidden_state; + float zoneout = lstm_param->zoneout_hidden_; + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float)); + ElementOptMul(state_buffer, &zoneout, state_buffer, batch * output_size, false); + } + + Tanh(cell_state, batch * hidden_size, hidden_buffer); + ElementMul(hidden_buffer, output_gate, hidden_buffer, batch * hidden_size); + + if (weight_project) { + float *left_matrix = hidden_buffer; + if (batch != 1) { + left_matrix = buffer[C6NUM]; + PackLstmInput(hidden_buffer, left_matrix, batch, hidden_size); + } + LstmMatMul(hidden_state, left_matrix, weight_project, NULL, batch, hidden_size, output_size, + lstm_param->proj_col_align_, batch == 1, buffer[C7NUM]); + } + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + ElementOptMulAcc(hidden_state, 1 - zoneout, state_buffer, batch * output_size); + } + (void)memcpy(output, hidden_state, batch * output_size * sizeof(float)); +} + +void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, + int col, int col_align, bool is_vec, float *packed_ptr) { + const float *weight_i = weight; + const float *bias_i = bias; + float *gate_i = gate_buffer; + for (int i = 0; i < 4; i++) { + LstmMatMul(gate_i, input, weight_i, bias_i, row, deep, col, col_align, is_vec, packed_ptr); + +#ifdef ENABLE_AVX + weight_i += deep * col_align; +#else + weight_i += deep * (is_vec ? col : col_align); +#endif + bias_i += col_align; + gate_i += row * col; + } +} + +void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, + const float *state_weight, const float *state_bias, const float *weight_project, float *hidden_state, + float *cell_state, float *buffer[C8NUM], const LstmParameter *lstm_param) { + float *packed_state = buffer[1]; + float *state_gate = buffer[C2NUM]; + float *cell_buffer = buffer[C3NUM]; + float *hidden_buffer = buffer[C4NUM]; + float *packed_output = buffer[C5NUM]; + bool is_vec = lstm_param->batch_ == 1; + // state * weight + if (is_vec) { + UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); + } else { + // pack state for matmul + PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); + UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); + } + ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + + // update input_gate + Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); + + // update forget_gate + Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate); + + // update cell_gate + Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); + // update 
cell state + UpdateState(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->zoneout_cell_); + + // update output_gate + Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); + // update output + UpdateOutput(hidden_state, output, cell_state, output_gate, weight_project, buffer, lstm_param); + + if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { + (void)memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); + } + + if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { + (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float)); + } +} + +void LstmUnidirectional(float *output, const float *packed_input, const float *weight_i, const float *weight_h, + const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, + float *buffer[C8NUM], const LstmParameter *lstm_param, bool is_backward) { + float *gate = buffer[0]; + for (int i = 0; i < 4; i++) { + const float *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, + lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, + OutType_Nhwc); + } + + float *input_gate = gate; + float *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = 0; t < lstm_param->seq_len_; t++) { + int real_t = is_backward ? 
lstm_param->seq_len_ - t - 1 : t; + float *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, NULL, + hidden_state, cell_state, buffer, lstm_param); + } +} + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[C9NUM], + const LstmParameter *lstm_param) { + // forward + float *packed_input = buffer[0]; + buffer += 1; + PackLstmInput(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_); + LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, buffer, + lstm_param, false); + + // backward + if (lstm_param->bidirectional_) { + const float *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; + const float *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; + float *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; + float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; + + LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, + backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.h new file mode 100644 index 00000000..f439bb6b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/lstm_fp32.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
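LstmStepUnit above is the standard cell update: gates i, f, o go through a sigmoid and the candidate g through tanh, then cell' = f * cell + i * g and hidden' = o * tanh(cell'), with zoneout blending and the optional projection layered on top. One element of that update as a scalar sketch (gate pre-activations assumed already accumulated, as in the code above; the helpers are illustrative):

#include <math.h>
static float sigmoidf(float x) { return 1.0f / (1.0f + expf(-x)); }

/* One element of the cell update, without zoneout or projection. */
static void lstm_cell_elem(float in_g, float forget_g, float cell_g, float out_g,
                           float *cell_state, float *hidden_state) {
  const float i = sigmoidf(in_g);
  const float f = sigmoidf(forget_g);
  const float g = tanhf(cell_g);
  const float o = sigmoidf(out_g);
  *cell_state = f * (*cell_state) + i * g; /* as in UpdateState */
  *hidden_state = o * tanhf(*cell_state);  /* as in UpdateOutput */
}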
+ */ + +#ifndef MINDSPORE_NNACL_FP32_LSTM_H_ +#define MINDSPORE_NNACL_FP32_LSTM_H_ + +#include "nnacl_c/lstm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order); + +void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, + bool is_bidirectional, int stride, const int32_t *order); + +void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + int b_stride, const int32_t *order); + +void PackLstmInput(const float *src, float *dst, int row, int deep); + +void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align, + bool is_vec, float *packed_ptr); + +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size); + +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); + +void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, + const float *state_weight, const float *state_bias, const float *weight_project, float *hidden_state, + float *cell_state, float *buffer[C8NUM], const LstmParameter *lstm_param); + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[C9NUM], + const LstmParameter *lstm_param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_LSTM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.c new file mode 100644 index 00000000..9fa54e67 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.c @@ -0,0 +1,248 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
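PackLstmWeight above repacks each gate matrix into the column-major tile layout the platform matmul expects (Col16 on AVX, Col4 on ARM32, Col8 elsewhere), and the optional order table permutes the source gate order into the internal one. A usage sketch for the input weights of one direction (the sizes and NULL order are illustrative; UP_ROUND and C8NUM are assumed to come in through nnacl_c/op_base.h):

#include <stdlib.h>
#include "nnacl_c/fp32/lstm_fp32.h"
#include "nnacl_c/op_base.h"

static float *pack_lstm_input_weights(const float *weight_i, int input_size, int hidden_size) {
  const int gates = 4; /* four gate matrices per direction */
  const int col_align = UP_ROUND(hidden_size, C8NUM); /* Col8 layout assumed */
  float *packed = (float *)malloc((size_t)gates * col_align * input_size * sizeof(float));
  if (packed != NULL) {
    /* NULL order keeps the gate order of the source weights unchanged. */
    PackLstmWeight(packed, weight_i, gates, input_size, hidden_size, col_align, NULL);
  }
  return packed;
}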
+ */ +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void GemmRowxColKernelFp32(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + __m512 dst_data[27]; + const float *src_sw[20]; + __m512 weight_data[6]; + for (int i = 0; i < C6NUM; ++i) { + weight_data[i] = _mm512_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (inc_flag & 0x01) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(dst + i * dst_stride + j * C16NUM); + } + } else if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(bias + j * C16NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_set1_ps(0.0f); + } + } + src_sw[i] = src + i * src_stride; + } + const float *weight_kernel = weight; + for (int k = 0; k < depth; ++k) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm512_loadu_ps(weight_kernel + j * C16NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = + _mm512_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j]); + } + } + weight_kernel += C16NUM * col_block; + } // k loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm512_min_ps(dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm512_max_ps(dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_storeu_ps(dst + i * dst_stride + j * C16NUM, dst_data[i * col_block + j]); + } + } +} + +void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, const int depth, + const int cur_col, const int col_align, const int row) { + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + GemmAvx512Kernel kernel[C4NUM][C13NUM]; + int max_shape[C4NUM] = {C12NUM, C12NUM, C8NUM, C6NUM}; + +#ifdef ENABLE_DEBUG + for (int i = 0; i < C4NUM; i++) { + for (int j = 0; j < C13NUM; j++) { + kernel[i][j] = GemmRowxColKernelFp32; + } + } +#else + kernel[0][1] = nnacl_gemm_avx512_1x16_kernel_nhwc_fp32; + kernel[0][2] = nnacl_gemm_avx512_2x16_kernel_nhwc_fp32; + kernel[0][3] = nnacl_gemm_avx512_3x16_kernel_nhwc_fp32; + kernel[0][4] = nnacl_gemm_avx512_4x16_kernel_nhwc_fp32; + kernel[0][5] = nnacl_gemm_avx512_5x16_kernel_nhwc_fp32; + kernel[0][6] = nnacl_gemm_avx512_6x16_kernel_nhwc_fp32; + kernel[0][7] = nnacl_gemm_avx512_7x16_kernel_nhwc_fp32; + kernel[0][8] = nnacl_gemm_avx512_8x16_kernel_nhwc_fp32; + kernel[0][9] = nnacl_gemm_avx512_9x16_kernel_nhwc_fp32; + kernel[0][10] = nnacl_gemm_avx512_10x16_kernel_nhwc_fp32; + kernel[0][11] = nnacl_gemm_avx512_11x16_kernel_nhwc_fp32; + kernel[0][12] = nnacl_gemm_avx512_12x16_kernel_nhwc_fp32; + + kernel[1][1] = nnacl_gemm_avx512_1x32_kernel_nhwc_fp32; + kernel[1][2] = nnacl_gemm_avx512_2x32_kernel_nhwc_fp32; + kernel[1][3] = nnacl_gemm_avx512_3x32_kernel_nhwc_fp32; + kernel[1][4] = nnacl_gemm_avx512_4x32_kernel_nhwc_fp32; + 
kernel[1][5] = nnacl_gemm_avx512_5x32_kernel_nhwc_fp32; + kernel[1][6] = nnacl_gemm_avx512_6x32_kernel_nhwc_fp32; + kernel[1][7] = nnacl_gemm_avx512_7x32_kernel_nhwc_fp32; + kernel[1][8] = nnacl_gemm_avx512_8x32_kernel_nhwc_fp32; + kernel[1][9] = nnacl_gemm_avx512_9x32_kernel_nhwc_fp32; + kernel[1][10] = nnacl_gemm_avx512_10x32_kernel_nhwc_fp32; + kernel[1][11] = nnacl_gemm_avx512_11x32_kernel_nhwc_fp32; + kernel[1][12] = nnacl_gemm_avx512_12x32_kernel_nhwc_fp32; + + kernel[2][1] = nnacl_gemm_avx512_1x48_kernel_nhwc_fp32; + kernel[2][2] = nnacl_gemm_avx512_2x48_kernel_nhwc_fp32; + kernel[2][3] = nnacl_gemm_avx512_3x48_kernel_nhwc_fp32; + kernel[2][4] = nnacl_gemm_avx512_4x48_kernel_nhwc_fp32; + kernel[2][5] = nnacl_gemm_avx512_5x48_kernel_nhwc_fp32; + kernel[2][6] = nnacl_gemm_avx512_6x48_kernel_nhwc_fp32; + kernel[2][7] = nnacl_gemm_avx512_7x48_kernel_nhwc_fp32; + kernel[2][8] = nnacl_gemm_avx512_8x48_kernel_nhwc_fp32; + + kernel[3][1] = nnacl_gemm_avx512_1x64_kernel_nhwc_fp32; + kernel[3][2] = nnacl_gemm_avx512_2x64_kernel_nhwc_fp32; + kernel[3][3] = nnacl_gemm_avx512_3x64_kernel_nhwc_fp32; + kernel[3][4] = nnacl_gemm_avx512_4x64_kernel_nhwc_fp32; + kernel[3][5] = nnacl_gemm_avx512_5x64_kernel_nhwc_fp32; + kernel[3][6] = nnacl_gemm_avx512_6x64_kernel_nhwc_fp32; +#endif + + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + // one time process 64 out_channel + int col_block = C64NUM; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = MSMIN(col_block, cur_col - col_index); + int row_block = max_shape[(col_block >> C4NUM) - 1]; + for (int m = 0; m < row; m += row_block) { + row_block = MSMIN(row_block, row - m); + kernel[(col_block >> C4NUM) - 1][row_block](c + col_index + m * col_align, a + m * depth + k, + b + col_index * depth + k * col_block, bias_data, act_flag, + row_block, col_block >> C4NUM, k_block, depth, col_align, inc_flag); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_align) { + // one time process 64 out_channel + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } +#ifdef ENABLE_DEBUG + GemmAvx512Kernel kernel[C4NUM] = {GemmRowxColKernelFp32, GemmRowxColKernelFp32, GemmRowxColKernelFp32, + GemmRowxColKernelFp32}; +#else + GemmAvx512Kernel kernel[C4NUM] = {nnacl_gemm_avx512_1x16_kernel_nhwc_fp32, nnacl_gemm_avx512_1x32_kernel_nhwc_fp32, + nnacl_gemm_avx512_1x48_kernel_nhwc_fp32, nnacl_gemm_avx512_1x64_kernel_nhwc_fp32}; +#endif + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + int col_block = C64NUM; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = MSMIN(col_block, cur_col - col_index); + kernel[(col_block >> C4NUM) - 1](c + col_index, a + k, b + col_index * depth + k * col_block, bias_data, act_flag, + 1, col_block >> C4NUM, k_block, depth, col_align, inc_flag); + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +// act_type must be 
0, 1, 3. 0: no_act, 1: relu, 3: relu6. +int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m, + int k, int act_type) { + // gemm dot is [m, k] * [k, 1] ==>> [m, 1] + // block 8 + MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps(); + MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM); + for (; m_index <= m - C8NUM; m_index += C8NUM) { + int k_index = 0; + MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]); + MS_SET_ZERO512X8_F32(dst16_) + for (; k_index <= k - C16NUM; k_index += C16NUM) { + __m512 weight = _mm512_loadu_ps(b + k_index); + MS_LOAD512X8_F32(src, a + m_index * k + k_index, k) + MS_FMADD512X8_F32(src, weight, dst16_) + } + MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1); + MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2); + MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3); + MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4); + MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5); + MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6); + MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7); + MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8); + for (; k_index < k; k_index++) { + MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index]; + MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k]; + MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k]; + MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k]; + MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k]; + MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k]; + MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k]; + MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k]; + } + + if (act_type != 0) { + dst = MS_MAX256_F32(dst, down_threshold256); + if (act_type == 3) { // 3: relu6 + dst = MS_MIN256_F32(dst, up_threshold256); + } + } + + MS_ST256_F32(c + m_index, dst); + } + return m_index; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.h new file mode 100644 index 00000000..43c4d79c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_fp32.h @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
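// Editorial sketch (not part of the patch): a scalar reference for the
// GemmIsNotPackOptimizeAVX512 block above. The intrinsic version walks the
// [m, k] x [k, 1] GEMV eight rows at a time, keeping each row's partial sums
// in one zmm accumulator and horizontally reducing it into a single lane of
// the 8-float result. The helper name below is illustrative only.
static void GemvRefFp32(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
  for (int i = 0; i < m; ++i) {
    float acc = bias[0];  // one output column, so a single bias value (the SIMD path also reads bias[0])
    for (int j = 0; j < k; ++j) {
      acc += a[i * k + j] * b[j];
    }
    if (act_type != 0) {
      acc = acc > 0.0f ? acc : 0.0f;  // relu floor
      if (act_type == 3) {            // 3: relu6
        acc = acc < 6.0f ? acc : 6.0f;
      }
    }
    c[i] = acc;
  }
}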
+ */ +#ifndef MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ +#include "nnacl_c/op_base.h" +typedef void (*GemmAvx512Kernel)(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +#ifdef __cplusplus +extern "C" { +#endif +void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int cur_col, int col_align); + +void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align, int row); + +int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m, + int k, int act_type); + +// 64 block +void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 48 block +void nnacl_gemm_avx512_8x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x48_kernel_nhwc_fp32(float 
*dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 32 block +void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_11x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_10x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_9x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, 
const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 16 block +void nnacl_gemm_avx512_12x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_11x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_10x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_9x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_8x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t 
src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c new file mode 100644 index 00000000..4d761d83 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.c @@ -0,0 +1,236 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
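// Editorial sketch (assumptions: C4NUM == 4 and max_shape == {12, 12, 8, 6},
// as in MatMulAvx512Fp32 above): the kernel tables are indexed by column
// family first, then by row count. col_block >> C4NUM maps the 16/32/48/64-wide
// column families to 1..4; wider tiles need more zmm accumulators, so they
// allow fewer rows. The inc_flag passed to every kernel encodes the k-chunk
// position: bit 0 set means "accumulate into existing dst", bit 1 set means
// "last chunk, apply the activation" -- hence inc_flag = C3NUM - (k == 0) on
// the final chunk and 1 - (k == 0) on the others. The helper below is
// illustrative only, not part of the patch.
static inline int Avx512ColFamilyIndex(int col_block) {
  return (col_block >> 4) - 1;  // 16 -> 0, 32 -> 1, 48 -> 2, 64 -> 3
}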
+ */ +#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void GemmRowxColMaskKernelFp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + __m512 dst_data[27]; + const float *src_sw[20]; + __m512 weight_data[6]; + __mmask16 mask16 = (__mmask16)(*mask); + for (int i = 0; i < C6NUM; ++i) { + weight_data[i] = _mm512_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (inc_flag & 0x01) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(dst + i * dst_stride + j * C16NUM); + } + } else if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(bias + j * C16NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_set1_ps(0.0f); + } + } + src_sw[i] = src + i * src_stride; + } + const float *weight_kernel = weight; + for (int k = 0; k < depth; ++k) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm512_loadu_ps(weight_kernel + j * C16NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (j == col_block - 1) { + dst_data[i * col_block + j] = + _mm512_mask3_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j], mask16); + } else { + dst_data[i * col_block + j] = + _mm512_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j]); + } + } + } + weight_kernel += C16NUM * col_block; + } // k loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (j == col_block - 1) { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = + _mm512_maskz_min_ps(mask16, dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = + _mm512_maskz_max_ps(mask16, dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_mask_storeu_ps(dst + i * dst_stride + j * C16NUM, mask16, dst_data[i * col_block + j]); + } else { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm512_min_ps(dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm512_max_ps(dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_storeu_ps(dst + i * dst_stride + j * C16NUM, dst_data[i * col_block + j]); + } + } + } +} + +void MatMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_, const int row) { + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + GemmAvx512MaskKernel kernel[C4NUM][C13NUM]; + int max_shape[C4NUM] = {C12NUM, C12NUM, C8NUM, C6NUM}; + +#ifdef ENABLE_DEBUG + for (int i = 0; i < C4NUM; i++) { + for (int j = 0; j < C13NUM; j++) { + kernel[i][j] = GemmRowxColMaskKernelFp32; + } + } +#else + kernel[0][1] = nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32; + kernel[0][2] = nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32; + kernel[0][3] = nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32; + kernel[0][4] 
= nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32; + kernel[0][5] = nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32; + kernel[0][6] = nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32; + kernel[0][7] = nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32; + kernel[0][8] = nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32; + kernel[0][9] = nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32; + kernel[0][10] = nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32; + kernel[0][11] = nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32; + kernel[0][12] = nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32; + + kernel[1][1] = nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32; + kernel[1][2] = nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32; + kernel[1][3] = nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32; + kernel[1][4] = nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32; + kernel[1][5] = nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32; + kernel[1][6] = nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32; + kernel[1][7] = nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32; + kernel[1][8] = nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32; + kernel[1][9] = nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32; + kernel[1][10] = nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32; + kernel[1][11] = nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32; + kernel[1][12] = nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32; + + kernel[2][1] = nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32; + kernel[2][2] = nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32; + kernel[2][3] = nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32; + kernel[2][4] = nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32; + kernel[2][5] = nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32; + kernel[2][6] = nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32; + kernel[2][7] = nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32; + kernel[2][8] = nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32; + + kernel[3][1] = nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32; + kernel[3][2] = nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32; + kernel[3][3] = nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32; + kernel[3][4] = nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32; + kernel[3][5] = nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32; + kernel[3][6] = nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32; +#endif + + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + // one time process 64 out_channel + int col_block = C64NUM; + u_int16_t avx512_mask = 0xFFFF; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + int less_tmp = cur_col - col_index; + if (less_tmp < col_block) { + col_block = UP_ROUND(less_tmp, C16NUM); + avx512_mask = (0xFFFF >> (col_block - less_tmp)); + } + int col_block_num = col_block >> C4NUM; + int row_block = max_shape[col_block_num - 1]; + for (int m = 0; m < row; m += row_block) { + row_block = MSMIN(row_block, row - m); + kernel[col_block_num - 1][row_block](c + col_index + m * col_, a + m * depth + k, + b + col_index * depth + k * col_block, bias_data, act_flag, row_block, + col_block_num, k_block, depth, col_, inc_flag, &avx512_mask); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +void MatVecMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_) { + // one time process 64 out_channel + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + 
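  // Note (editorial): act_flag is a two-bit code consumed by the kernels:
  // bit 0 requests the 6.0f clamp (relu6) and bit 1 requests the 0.0f floor
  // (relu). ActType_Relu6 sets both bits, so a kernel applies the floor
  // first and then the clamp.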
if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } +#ifdef ENABLE_DEBUG + GemmAvx512MaskKernel kernel[C4NUM] = {GemmRowxColMaskKernelFp32, GemmRowxColMaskKernelFp32, GemmRowxColMaskKernelFp32, + GemmRowxColMaskKernelFp32}; +#else + GemmAvx512MaskKernel kernel[C4NUM] = { + nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32, nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32, + nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32, nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32}; +#endif + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + int col_block = C64NUM; + u_int16_t avx512_mask = 0xFFFF; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + int less_tmp = cur_col - col_index; + if (less_tmp < col_block) { + col_block = UP_ROUND(less_tmp, C16NUM); + avx512_mask = (0xFFFF >> (col_block - less_tmp)); + } + int col_block_num = col_block >> C4NUM; + + kernel[col_block_num - 1](c + col_index, a + k, b + col_index * depth + k * col_block, bias_data, act_flag, 1, + col_block_num, k_block, depth, col_, inc_flag, &avx512_mask); + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.h new file mode 100644 index 00000000..aa0fdfba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx512_mask_fp32.h @@ -0,0 +1,209 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
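// Editorial sketch of the ragged-tail masking above (assumption: UP_ROUND(x, y)
// rounds x up to a multiple of y): with less_tmp = 20 leftover columns,
// col_block = UP_ROUND(20, 16) = 32 and avx512_mask = 0xFFFF >> (32 - 20) =
// 0x000F, so the masked loads/stores in GemmRowxColMaskKernelFp32 touch only
// the low 4 lanes of the final 16-float group. Equivalently, for the
// 1 <= leftover <= 16 columns that fall in that last group:
static inline unsigned short Avx512TailMask(int leftover_in_last_group) {
  return (unsigned short)(0xFFFFu >> (16 - leftover_in_last_group));  // illustrative helper, not in the patch
}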
+ */ +#ifndef MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_ +#include <stdint.h> +#include <sys/types.h> +#include "nnacl_c/op_base.h" +typedef void (*GemmAvx512MaskKernel)(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +#ifdef __cplusplus +extern "C" { +#endif + +void GemmRowxColMaskKernelFp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +void MatVecMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_); + +void MatMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_, const int row); + +// 64 block +void nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 48 block +void nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block,
const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 32 block +void nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t 
src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 16 block +void nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, 
const u_int16_t *mask); +void nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c new file mode 100644 index 00000000..7f49814e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.c @@ -0,0 +1,954 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "nnacl_c/fp32/matmul_avx_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h" + +void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align) { + // one time process 32 out_channel + int col_block = C32NUM; + int act_flag = C0NUM; + if (act_type == ActType_Relu6) { + act_flag += C1NUM; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + MatVecMulKernel kernel[4] = {MatVecMul1x8Kernel, MatVecMul1x16Kernel, MatVecMul1x24Kernel, MatVecMul1x32Kernel}; + const float *bias_data = bias; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block; + kernel[(col_block >> C3NUM) - 1](c + col_index, a, b + col_index * depth, bias_data, act_flag, 1, + col_block >> C3NUM, col_align, depth); + if (bias_data != NULL) { + bias_data += col_block; + } + } +} + +void MatMulAvxFp32(const float *a, const float *b, float *c, const float *bias, const int act_type, const int depth, + const int cur_col, const int col_align, const int row) { + // one time process 32 out_channel + int col_block = C32NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + int row_tile[4] = {C8NUM, C6NUM, C4NUM, C3NUM}; + MatVecMulKernel kernel[4][2] = {{MatVecMul1x8Kernel, MatMul8x8Kernel}, + {MatVecMul1x16Kernel, MatMul6x16Kernel}, + {MatVecMul1x24Kernel, MatMul4x24Kernel}, + {MatVecMul1x32Kernel, MatMul3x32Kernel}}; + const float *bias_data = bias; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = cur_col - col_index < col_block ? 
cur_col - col_index : col_block; + int row_block = row_tile[(col_block >> C3NUM) - 1]; + for (int r = 0; r < row; r += row_block) { + if (row_block > row - r) { + row_block = 1; + } + kernel[(col_block >> C3NUM) - 1][row_block / row_tile[(col_block >> C3NUM) - 1]]( + c + col_index + r * col_align, a + r * depth, b + col_index * depth, bias_data, act_flag, row_block, + col_block >> C3NUM, col_align, depth); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } +} + +void MatMul3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + col_algin *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vbroadcastss (%0), %%ymm12\n" // src + "vbroadcastss (%0, %7), %%ymm13\n" + "vbroadcastss (%0, %7, 2), %%ymm14\n" + "vmovups (%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "vmovups 0x20(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm5\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n" + + "vmovups 0x40(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm10\n" + + "vmovups 0x60(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + "addq $128, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, 
%%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, 0x60(%5)\n" + "vmovups %%ymm4, (%5, %6)\n" // dst_1 + "vmovups %%ymm5, 0x20(%5, %6)\n" + "vmovups %%ymm6, 0x40(%5, %6)\n" + "vmovups %%ymm7, 0x60(%5, %6)\n" + "vmovups %%ymm8, (%5, %6, 2)\n" // dst_2 + "vmovups %%ymm9, 0x20(%5, %6, 2)\n" + "vmovups %%ymm10, 0x40(%5, %6, 2)\n" + "vmovups %%ymm11, 0x60(%5, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), + "r"(deep * sizeof(float)) // 7 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // deep_c8 + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 512(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 544(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 576(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 608(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 640(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 672(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 704(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 736(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 768(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 800(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 832(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 864(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 896(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 928(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 960(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 992(%1), %%ymm4, %%ymm3\n" + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" // deep_remainder + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + "addq $128, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" 
+ "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, 0x60(%5)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_3 = dst + C3NUM * col_algin; + col_algin *= sizeof(float); + size_t src_3_step = C3NUM * deep * sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vmovups (%1), %%ymm12\n" // weight + "vmovups 0x20(%1), %%ymm13\n" + "vmovups 0x40(%1), %%ymm14\n" + + "vbroadcastss (%0), %%ymm15\n" // src + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm2\n" + + "vbroadcastss (%0, %9), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm5\n" + + "vbroadcastss (%0, %9, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "vbroadcastss (%0, %7), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + "addq $96, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, 
%%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, (%5, %6)\n" + "vmovups %%ymm4, 0x20(%5, %6)\n" // dst_1 + "vmovups %%ymm5, 0x40(%5, %6)\n" + "vmovups %%ymm6, (%5, %6, 2)\n" + "vmovups %%ymm7, 0x20(%5, %6, 2)\n" + "vmovups %%ymm8, 0x40(%5, %6, 2)\n" // dst_2 + "vmovups %%ymm9, (%8)\n" + "vmovups %%ymm10, 0x20(%8)\n" + "vmovups %%ymm11, 0x40(%8)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(src_3_step), "r"(dst_3), + "r"(deep * sizeof(float)) // 9 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + + "1:\n" // deep + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 512(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 544(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 576(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 608(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 640(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 672(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 704(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 736(%1), %%ymm4, %%ymm2\n" + "addq $768, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" // deep_remainder + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "addq $96, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, 
%%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_3 = dst + 3 * col_algin; + float *dst_5 = dst + 5 * col_algin; + col_algin *= sizeof(float); + size_t src_3_step = 3 * deep * sizeof(float); + size_t src_5_step = 5 * deep * sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups 0x20(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups 0x20(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vmovups (%1), %%ymm12\n" // weight + "vmovups 0x20(%1), %%ymm13\n" + + "vbroadcastss (%0), %%ymm14\n" // src_0 + "vbroadcastss (%0, %11), %%ymm15\n" // src_1 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm3\n" + + "vbroadcastss (%0, %11, 2), %%ymm14\n" // src_2 + "vbroadcastss (%0, %8), %%ymm15\n" // src_3 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm5\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + + "vbroadcastss (%0, %11, 4), %%ymm14\n" // src_4 + "vbroadcastss (%0, %9), %%ymm15\n" // src_5 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm8\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm11\n" + + "addq $64, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + 
"vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, (%5, %6)\n" // dst_1 + "vmovups %%ymm3, 0x20(%5, %6)\n" + "vmovups %%ymm4, (%5, %6, 2)\n" // dst_2 + "vmovups %%ymm5, 0x20(%5, %6, 2)\n" + "vmovups %%ymm6, (%7)\n" // dst_3 + "vmovups %%ymm7, 0x20(%7)\n" + "vmovups %%ymm8, (%5, %6, 4)\n" // dst_4 + "vmovups %%ymm9, 0x20(%5, %6, 4)\n" + "vmovups %%ymm10, (%10)\n" // dst_5 + "vmovups %%ymm11, 0x20(%10)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(dst_3), "r"(src_3_step), + "r"(src_5_step), "r"(dst_5), "r"(deep * sizeof(float)) // 11 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" // deep_c8 + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm1\n" + "addq $512, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "addq $64, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t 
col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" // deep_c8 + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 32(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm0\n" + "addq $256, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "addq $32, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_5 = dst + C5NUM * col_algin; + col_algin *= sizeof(float); + size_t dst_3_step = C3NUM * col_algin; + size_t src_3_step = C3NUM * deep * sizeof(float); + const float *src_5 = C5NUM * deep + src; + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups (%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups (%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups (%2), %%ymm7\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + + "1:\n" // deep + "vmovups (%1), %%ymm15\n" // weight + + "vbroadcastss (%0), %%ymm8\n" // src_0 + "vbroadcastss (%0, %11), %%ymm9\n" // src_1 + "vbroadcastss (%0, %11, 2), %%ymm10\n" // src_2 + "vbroadcastss (%0, %8), %%ymm11\n" // src_3 + "vfmadd231ps %%ymm8, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm3\n" + + "vbroadcastss (%0, %11, 4), %%ymm8\n" // src_4 + "vbroadcastss (%9), %%ymm9\n" // src_5 + "vbroadcastss (%9, %11, 1), %%ymm10\n" // src_6 + "vbroadcastss (%9, %11, 2), %%ymm11\n" // src_7 + "vfmadd231ps %%ymm8, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm5\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm7\n" + + "addq $32, %1\n" + "addq $4, %0\n" + "addq $4, %9\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps 
%%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, (%5, %6)\n" + "vmovups %%ymm2, (%5, %6, 2)\n" + "vmovups %%ymm3, (%5, %7)\n" + "vmovups %%ymm4, (%5, %6, 4)\n" + "vmovups %%ymm5, (%10)\n" + "vmovups %%ymm6, (%10, %6)\n" + "vmovups %%ymm7, (%10, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(dst_3_step), // 7 + "r"(src_3_step), "r"(src_5), "r"(dst_5), "r"(deep * sizeof(float)) // 11 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +#ifdef ENABLE_DEBUG +void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + __m256 dst_data[12]; + const float *src_sw[12]; + __m256 weight_data[4]; + for (int i = 0; i < C4NUM; ++i) { + weight_data[i] = _mm256_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm256_set1_ps(0.0f); + } + } + src_sw[i] = src + i * deep; + } + const float *weight_kernel = weight; + for (int ic = 0; ic < deep; ++ic) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = + _mm256_fmadd_ps(_mm256_set1_ps(src_sw[i][ic]), weight_data[j], dst_data[i * col_block + j]); + } + } + weight_kernel += C8NUM * col_block; + } // ic loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm256_min_ps(dst_data[i * col_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm256_max_ps(dst_data[i * col_block + j], _mm256_set1_ps(0.0f)); + } + _mm256_storeu_ps(dst + i * col_algin + j * C8NUM, dst_data[i * col_block + j]); + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h new file mode 100644 index 00000000..88d0242d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_avx_fp32.h @@ -0,0 +1,63 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_MATMUL_AVX_H_
+#define MINDSPORE_NNACL_FP32_MATMUL_AVX_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef void (*DeconvAvxKernel)(const float *src, const float *weight, float *dst, int col, int row, int depth,
+                                int stride);
+void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, int kernel_plane);
+void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth,
+                       size_t row, size_t col, size_t stride, size_t write_mode);
+typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                                size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
+                      int col_align);
+void MatMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
+                   int col_align, int row);
+void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                        size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatMul3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                      size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatMul4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                      size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatMul6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                      size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatMul8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                     size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+#ifdef ENABLE_DEBUG
+void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride);
+
+void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                            size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_FP32_MATMUL_AVX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c
new file mode 100644
index 00000000..f88f44c0
--- 
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.c @@ -0,0 +1,822 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_avx512_fp32.h" +#include "nnacl_c/matmul_fp32_simd.h" + +#ifndef ENABLE_ARM +void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) { + for (int ci = 0; ci < col; ci++) { + float value = 0; + for (int di = 0; di < depth; di++) { + value += a[di] * b[ci * depth + di]; + } + if (bias != NULL) value += bias[ci]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value); + c[ci] = value; + } +} +#endif + +void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int col) { + int col8 = col / C8NUM * C8NUM; + int ci = 0; + for (; ci < col8; ci += C8NUM, c += C8NUM) { +#ifdef ENABLE_NEON + float32x4_t value0 = vdupq_n_f32(0.0f); + float32x4_t value1 = vdupq_n_f32(0.0f); + for (int di = 0; di < depth; ++di, b += C8NUM) { + value0 += vdupq_n_f32(a[di]) * vld1q_f32(b); + value1 += vdupq_n_f32(a[di]) * vld1q_f32(b + C4NUM); + } + if (bias != NULL) { + value0 += vld1q_f32(bias + ci); + value1 += vld1q_f32(bias + ci + C4NUM); + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + value0 = vmaxq_f32(value0, vdupq_n_f32(0.0f)); + value1 = vmaxq_f32(value1, vdupq_n_f32(0.0f)); + } + if (act_type == ActType_Relu6) { + value0 = vminq_f32(value0, vdupq_n_f32(6.0f)); + value1 = vminq_f32(value1, vdupq_n_f32(6.0f)); + } + vst1q_f32(c, value0); + vst1q_f32(c + 4, value1); +#else + float value[C8NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C8NUM) { + for (int j = 0; j < C8NUM; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < C8NUM; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, C8NUM * sizeof(float)); +#endif + } + int res = col - col8; + float value[C8NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C8NUM) { + for (int j = 0; j < res; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < res; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, res * sizeof(float)); +} + +#ifdef ENABLE_ARM32 +void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int col) { + int col4 = col / C4NUM * C4NUM; + int ci = 0; + for (; ci < col4; ci += C4NUM, c += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t value = vdupq_n_f32(0.0f); + for (int di = 0; di < depth; ++di, b += C4NUM) { + value += vdupq_n_f32(a[di]) * vld1q_f32(b); + } + if (bias != NULL) { + value += vld1q_f32(&(bias[ci])); + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) 
{ + value = vmaxq_f32(value, vdupq_n_f32(0.0f)); + } + if (act_type == ActType_Relu6) { + value = vminq_f32(value, vdupq_n_f32(6.0f)); + } + vst1q_f32(c, value); +#else + float value[C4NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C4NUM) { + for (int j = 0; j < C4NUM; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < C4NUM; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, C4NUM * sizeof(float)); +#endif + } + int res = col - col4; + float value[C4NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C4NUM) { + for (int j = 0; j < res; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < res; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, res * sizeof(float)); +} +#endif + +#ifdef ENABLE_ARM64 +// 4x8 +void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col, + int align_col) { + int ci = 0; + for (; ci < align_col - C8NUM + 1; ci += C8NUM) { + float32x4_t acc_0; + float32x4_t acc_1; + if (bias != NULL) { + acc_0 = vld1q_f32(bias + ci); + acc_1 = vld1q_f32(bias + ci + C4NUM); + } else { + acc_0 = vdupq_n_f32(0.0f); + acc_1 = vdupq_n_f32(0.0f); + } + const float *bv_base = b + ci * depth; + int di = 0; + for (; di < depth - C4NUM + 1; di += C4NUM) { + float32x4_t av = vld1q_f32(a + di); + float32x4_t bv_00 = vld1q_f32(bv_base); + float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_01 = vld1q_f32(bv_base); + float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_02 = vld1q_f32(bv_base); + float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_03 = vld1q_f32(bv_base); + float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]); + acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]); + acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]); + acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]); + acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]); + acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]); + acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]); + acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]); + } + if (di < depth) { + for (; di < depth; ++di) { + float ai = a[di]; + float32x4_t bv0 = vld1q_f32(bv_base); + float32x4_t bv1 = vld1q_f32(bv_base + C4NUM); + acc_0 = vmlaq_n_f32(acc_0, bv0, ai); + acc_1 = vmlaq_n_f32(acc_1, bv1, ai); + bv_base += C8NUM; + } + } // only save actual col num data + if (ci + C4NUM - 1 >= col) { + int c_remain = col - ci; + for (int i = 0; i < c_remain; ++i) { + if (act_type == ActType_Relu) { + c[i] = MSMAX(acc_0[i], 0.0f); + } else if (act_type == ActType_Relu6) { + c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f); + } else { + c[i] = acc_0[i]; + } + } + return; + } + if (act_type == ActType_Relu) { + acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f)); + } else if (act_type == ActType_Relu6) { + acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f)); + } + vst1q_f32(c, acc_0); + if (ci + C8NUM - 1 >= col) { + int c_remain = col - ci - C4NUM; + for (int i = 0; i < c_remain; ++i) { + if (act_type == ActType_Relu) { + c[C4NUM + i] = MSMAX(acc_1[i], 0.0f); + } else if (act_type == ActType_Relu6) { + c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f); + } else { + c[C4NUM + i] = acc_1[i]; + } + } + return; + } + if (act_type == ActType_Relu) { + acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f)); + } else if (act_type == ActType_Relu6) { + 
acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
+    }
+    vst1q_f32(c + C4NUM, acc_1);
+    c += C8NUM;
+  }
+}
+#endif
+
+void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+                int col, int stride, int out_type) {
+  if (out_type == OutType_Nhwc) {
+    for (int r = 0; r < row; r++) {
+      for (int c = 0; c < col; c++) {
+        int r12div = r / 12, r12mod = r % 12;
+        int c8div = c / 8, c8mod = c % 8;
+        size_t ci = r * stride + c;
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r12div * deep * 12 + d * 12 + r12mod;
+          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        ADD_BIAS(value, bias, c)
+        DO_RELU(value, act_type)
+        DO_RELU6(value, act_type)
+        dst[ci] = value;
+      }
+    }
+  } else if (out_type == OutType_C8) {
+    int col_8 = UP_ROUND(col, C8NUM);
+    int row_12 = UP_ROUND(row, C12NUM);
+    for (int r = 0; r < row_12; r++) {
+      for (int c = 0; c < col_8; c++) {
+        int r12div = r / C12NUM, r12mod = r % C12NUM;
+        int c8div = c / C8NUM, c8mod = c % C8NUM;
+        size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
+          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        ADD_BIAS(value, bias, c)
+        DO_RELU(value, act_type)
+        DO_RELU6(value, act_type)
+        dst[ci] = value;
+      }
+    }
+  } else if (out_type == OutType_TileC8) {
+    for (int i = 0; i < row; ++i) {
+      int src_r_offset = i;
+      int dst_r_offset = i * col * stride;
+      for (int j = 0; j < col; ++j) {
+        int c8div = j / 8, c8mod = j % 8;
+        size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
+        float value = 0;
+        for (int d = 0; d < deep; ++d) {
+          size_t ai = src_r_offset + d * C12NUM;
+          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        ADD_BIAS(value, bias, j)
+        DO_RELU(value, act_type)
+        DO_RELU6(value, act_type)
+        dst[ci] = value;
+      }
+    }
+  }
+}
+
+void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
+               int col, size_t stride, int out_type) {
+#ifdef ENABLE_ARM64
+  if (out_type == OutType_C8) {
+    MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
+  } else if (out_type == OutType_Nhwc && deep > C512NUM) {
+    BigMatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride);
+  } else {
+    MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
+  }
+#elif ENABLE_ARM32
+  if (out_type == OutType_C8) {
+    MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
+  } else if (out_type == OutType_Nhwc) {
+    MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1);
+  } else {
+    MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
+  }
+#elif ENABLE_AVX
+  MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
+#elif ENABLE_SSE
+  MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
+#else
+  MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
+#endif
+}
+
+#define ActCompute(bit_num, down_threshold, up_threshold) \
+  if (act_type != 0) {                                    \
+    dst = MS_MAX##bit_num##_F32(dst, down_threshold);     \
+    if (act_type == 3) {                                  \
+      dst = MS_MIN##bit_num##_F32(dst, up_threshold);     \
+    }                                                     \
+  }
+
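+// For reference, the scalar instantiation ActCompute(32, 0, C6NUM) used by the
+// tail loops below expands to the following (MS_MAX32_F32/MS_MIN32_F32 being
+// the scalar max/min helpers from the nnacl intrinsics headers):
+//   if (act_type != 0) {
+//     dst = MS_MAX32_F32(dst, 0);        // relu lower bound
+//     if (act_type == 3) {               // ActType_Relu6
+//       dst = MS_MIN32_F32(dst, C6NUM);  // upper clamp at 6.0f
+//     }
+//   }
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.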
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
+
+  for (; index < row; ++index) {
+    float dst = a[index] * b[0] + bias[0];
+    ActCompute(32, 0, C6NUM);
+    c[index] = dst;
+  }
+}
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                            int act_type) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(Row1Deep1GemmIsNotPack, index, a, b, c, bias, col, act_type);
+  for (; index < col; ++index) {
+    float dst = a[0] * b[index] + bias[index];
+    ActCompute(32, 0, C6NUM);
+    c[index] = dst;
+  }
+}
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                                  int act_type) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(Row1Deep1NoBiasGemmIsNotPack, index, a, b, c, bias, col, act_type);
+  for (; index < col; ++index) {
+    float dst = a[0] * b[index];
+    ActCompute(32, 0, C6NUM);
+    c[index] = dst;
+  }
+}
+
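+// A plain-C reference for the [m, k] x [k, 1] product that GemmIsNotPackOptimize
+// below computes with AVX512/AVX blocking. This is a hedged sketch for
+// validation only; RefGemmMx1 is an illustrative name, not part of the nnacl API.
+static inline void RefGemmMx1(const float *a, const float *b, float *c, const float *bias, int m, int k,
+                              int act_type) {
+  for (int m_index = 0; m_index < m; ++m_index) {
+    float dst = bias[0];  // these kernels assume a non-NULL single-element bias
+    for (int k_index = 0; k_index < k; ++k_index) {
+      dst += a[m_index * k + k_index] * b[k_index];
+    }
+    ActCompute(32, 0, C6NUM);  // same activation clamp as the kernels above
+    c[m_index] = dst;
+  }
+}
+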
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
+  // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
+  int m_index = 0;
+
+  SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type);
+
+#ifdef ENABLE_AVX
+  // block 4
+  MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
+  MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
+  for (; m_index <= m - C4NUM; m_index += C4NUM) {
+    int k_index = 0;
+    MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]);
+    MS_SET_ZERO256X4_F32(dst_)
+    for (; k_index <= k - C8NUM; k_index += C8NUM) {
+      MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index);
+      MS_LOAD256X4_F32(src, a + m_index * k + k_index, k);
+      MS_FMADD256X4_F32(src, weight, dst_);
+    }
+    MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1);
+    MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2);
+    MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3);
+    MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4);
+    for (; k_index < k; ++k_index) {
+      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
+      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
+      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
+      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
+    }
+    ActCompute(128, down_threshold128, up_threshold128);
+    MS_ST128_F32(c + m_index, dst);
+  }
+#endif
+
+  // block 1
+  for (; m_index < m; m_index++) {
+    float dst = bias[0];
+    int k_index = 0;
+
+    SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
+    SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
+
+    for (; k_index < k; k_index++) {
+      dst += b[k_index] * a[m_index * k + k_index];
+    }
+    ActCompute(32, 0, C6NUM);
+    c[m_index] = dst;
+  }
+}
+
+void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth,
+                         int64_t cur_col, int64_t col) {
+  int inc_flag = 0;
+  int64_t k = 0;
+  for (; k <= depth - C1500NUM; k += C1500NUM) {
+    inc_flag = (k == 0) + (k + C1500NUM == depth ? C2NUM : 0);
+    int64_t oc_index = 0;
+    SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, C1500NUM, cur_col, col, inc_flag);
+    for (; oc_index < cur_col; ++oc_index) {
+      float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]);
+      for (int64_t k_index = 0; k_index < C1500NUM; ++k_index) {
+        dst += a[k_index] * b[oc_index + k_index * col];
+      }
+      if ((inc_flag & 0x2) != 0) {
+        ActCompute(32, 0, C6NUM);
+      }
+      c[oc_index] = dst;
+    }
+    a += C1500NUM;
+    b += C1500NUM * col;
+  }
+  if (k == depth) {
+    return;
+  }
+  inc_flag = (k == 0) + C2NUM;
+  int64_t oc_index = 0;
+  SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, depth - k, cur_col, col, inc_flag);
+  for (; oc_index < cur_col; ++oc_index) {
+    float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]);
+    for (int64_t k_index = 0; k_index < depth - k; ++k_index) {
+      dst += a[k_index] * b[oc_index + k_index * col];
+    }
+    ActCompute(32, 0, C6NUM);
+    c[oc_index] = dst;
+  }
+}
+
+#ifdef ENABLE_ARM64
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
+                     size_t act_type) {
+  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
+  // 9: WriteBack
+  asm volatile(
+    "mov x8, %[input]\n"
+    "mov x9, %[weight]\n"
+    "mov x10, %[deep]\n"
+    "add x5, %[input], %[deep], LSL #2\n"
+    "add x6, %[input], %[deep], LSL #3\n"
+    "add x7, x5, %[deep], LSL #3\n"
+    "dup v0.2d, xzr\n"
+    "dup v1.2d, xzr\n"
+    "dup v2.2d, xzr\n"
+    "dup v3.2d, xzr\n"
+    "subs x10, x10, #16\n"
+    "blt 2f\n"
+    "1:\n" // LoopD16
+    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
+    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
+    "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n"
+    "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n"
+    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
+    "fmla v0.4s, v4.4s, v28.4s\n" +
"fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v20.4s}, [x6], #16\n" + "ld1 {v24.4s}, [x7], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "dup v20.2d, xzr\n" + "dup v24.2d, xzr\n" + "dup v28.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v20.d}[0], [x6], #8\n" + "ld1 {v24.d}[0], [x7], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v20.s}[2], [x6]\n" + "ld1 {v24.s}[2], [x7]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v5.4s, v2.4s, v3.4s\n" + "faddp v0.4s, v4.4s, v5.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.4s, v0.4s, v1.4s\n" + "9:\n" + "cbz %[act], 10f\n" + "dup v1.2d, xzr\n" + "fmax v0.4s, v0.4s, v1.4s\n" + "cmp %[act], #3\n" + "bne 10f\n" + "movi v1.4s, #6\n" + "scvtf v1.4s, v1.4s\n" + "fmin v0.4s, v0.4s, v1.4s\n" + "10:\n" + "st1 {v0.4s}, [%[output]]\n" + + : + : [input] "r"(input), [weight] "r"(weight), [output] "r"(output), [bias] "r"(bias), [deep] "r"(deep), + [act] "r"(act_type) + : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v16.4s, v17.4s}, [x5], 
#32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.2s, v0.2s, v1.2s\n" + "9:\n" + "cbz %[act], 10f\n" + "fmov d1, xzr\n" + "fmax v0.2s, v0.2s, v1.2s\n" + "cmp %[act], #3\n" + "bne 10f\n" + "movi v1.2s, #6\n" + "scvtf v1.2s, v1.2s\n" + "fmin v0.2s, v0.2s, v1.2s\n" + "10:\n" + "st1 {v0.2s}, [%[output]]\n" + + : + : [input] "r"(input), [weight] "r"(weight), [output] "r"(output), [bias] "r"(bias), [deep] "r"(deep), + [act] "r"(act_type) + : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", + "v30", "v31", "memory"); +} + +void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "dup v0.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[3], [x8]\n" + "ld1 {v28.s}[3], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v0.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1 {v1.s}[0], [%[bias]]\n" + "fadd s0, s0, s1\n" + "9:\n" + "cbz %[act], 10f\n" + "fmov s1, wzr\n" + "fmax s0, s0, s1\n" + "cmp %[act], #3\n" + "bne 10f\n" + "mov x10, #6\n" + "scvtf s1, x10\n" + "fmin s0, s0, s1\n" + "10:\n" + "str s0, [%[output]]\n" + + : + : [input] 
"r"(input), [weight] "r"(weight), [output] "r"(output), [bias] "r"(bias), [deep] "r"(deep), + [act] "r"(act_type) + : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31"); +} + +void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, + int deep, int act_type) { + const float *input = a + start_row * deep; + float *output = c + start_row; + const int step = C4NUM * deep; + for (; start_row <= end_row - C4NUM; start_row += C4NUM) { + MatMul4x1Kernel(input, b, output, bias, deep, act_type); + input += step; + output += C4NUM; + } + for (; start_row <= end_row - C2NUM; start_row += C2NUM) { + MatMul2x1Kernel(input, b, output, bias, deep, act_type); + input += C2NUM * deep; + output += C2NUM; + } + if (start_row == end_row - 1) { + MatMul1x1Kernel(input, b, output, bias, deep, act_type); + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.h new file mode 100644 index 00000000..fc75648d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_NNACL_FP32_MATMUL_H_
+#define MINDSPORE_NNACL_FP32_MATMUL_H_
+
+#include <float.h>
+#include <string.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/matmul_avx_fp32.h"
+
+#define ADD_BIAS(value, bias, c) \
+  if (bias != NULL) value = value + bias[c];
+
+#define DO_RELU(value, act_type) \
+  if (act_type == ActType_Relu) value = MSMAX(0.0f, value);
+
+#define DO_RELU6(value, act_type)                            \
+  if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \
+  if (act_type == ActType_Relu6) value = MSMAX(0.0f, value);
+
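+// Hedged illustration of how the kernels in matmul_fp32.c combine the three
+// macros above on one accumulator ('value', 'bias', 'c', 'act_type' are the
+// local names used there):
+//   float value = 0.0f;
+//   for (int d = 0; d < deep; ++d) value += a[...] * b[...];
+//   ADD_BIAS(value, bias, c)    // value += bias[c] when bias != NULL
+//   DO_RELU(value, act_type)    // max(0.0f, value) for ActType_Relu
+//   DO_RELU6(value, act_type)   // clamp to [0.0f, 6.0f] for ActType_Relu6
+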
+#ifdef __cplusplus
+extern "C" {
+#endif
+void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
+               int col, size_t stride, int out_type);
+void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+
+#ifdef ENABLE_ARM64
+void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col, size_t stride, size_t writeNhwc, size_t WriteWino);
+void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                          int col, size_t stride, size_t write_mode);
+void BigMatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                             int row, int col, size_t stride);
+void MatmulFloatNeon64OptRow8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                              int row, int col, size_t stride, size_t write_mode);
+void MatmulFloatNeon64OptRow4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                              int row, int col, size_t stride, size_t write_mode);
+void MatmulFloatNeon64OptRow12(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                               int row, int col, size_t stride, size_t write_mode);
+void MatVecMulPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
+void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
+                         int align_col);
+
+#elif defined(ENABLE_ARM32)
+void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col, int stride, size_t writeNhwc, size_t WriteWino);
+void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                          int col, int stride, int write_mode);
+void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+                              int row, int col, int stride, int write_mode);
+
+#elif defined(ENABLE_SSE)
+void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col);
+void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                         int col, int stride, int write_mode);
+#endif
+
+void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+                int col, int stride, int out_type);
+
+void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type);
+
+void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                            int act_type);
+
+void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
+                                  int act_type);
+
+void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type);
+
+void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth,
+                         int64_t cur_col, int64_t col);
+#ifdef ENABLE_ARM64
+void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
+                        int deep, int act_type);
+#endif
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_FP32_MATMUL_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32_simd.h.in
new file mode 100644
index 00000000..2c91a3e6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/matmul_fp32_simd.h.in
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// clang-format off
+#ifndef MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+static inline int64_t GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int row,
+                                                      int deep, int act_type) {
+  SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
+  SIMD_F32 up_threshold = SIMD_MOV_F32(6);
+  SIMD_F32 b_data16 = SIMD_MOV_F32(b[0]);
+  SIMD_F32 bias_data16 = SIMD_MOV_F32(bias[0]);
+  for (int block_max_size = row - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 a_data = SIMD_LD_F32(a + index);
+    SIMD_F32 dst = SIMD_FMADD_F32(b_data16, a_data, bias_data16);
+    if (act_type != 0) {
+      dst = SIMD_MAX_F32(dst, down_threshold);
+      if (act_type == 3) {
+        dst = SIMD_MIN_F32(dst, up_threshold);
+      }
+    }
+    SIMD_ST_F32(c + index, dst);
+  }
+
+  return index;
+}
+
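+// Each @SIMD_INSTRUCTION@ template in this file is instantiated per enabled
+// ISA at build time, with BLOCK_NUM as that ISA's float lane count. A hedged
+// sketch of the dispatch pattern used by the callers in matmul_fp32.c:
+//   int index = 0;
+//   SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
+//   for (; index < row; ++index) { /* scalar remainder */ }
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.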
+static inline int64_t Row1Deep1GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int col,
+                                                               int act_type) {
+  SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
+  SIMD_F32 up_threshold = SIMD_MOV_F32(6);
+  SIMD_F32 vec_a = SIMD_MOV_F32(a[0]);
+  if (act_type == 1) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 vec_bias = SIMD_LD_F32(bias + index);
+      SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias);
+      SIMD_ST_F32(c + index, SIMD_MAX_F32(dst, down_threshold));  // relu
+    }
+  } else if (act_type == 3) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 vec_bias = SIMD_LD_F32(bias + index);
+      SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias);
+      SIMD_ST_F32(c + index, SIMD_CLAMP_F32(dst, down_threshold, up_threshold));  // relu6
+    }
+  } else {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 vec_bias = SIMD_LD_F32(bias + index);
+      SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias);
+      SIMD_ST_F32(c + index, dst);  // no_act
+    }
+  }
+
+  return index;
+}
+
+// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
+static inline int64_t Row1Deep1NoBiasGemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int col,
+                                                                     int act_type) {
+  SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f);
+  SIMD_F32 up_threshold = SIMD_MOV_F32(6);
+  SIMD_F32 vec_a = SIMD_MOV_F32(a[0]);
+  if (act_type == 1) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b);
+      SIMD_ST_F32(c + index, SIMD_MAX_F32(dst, down_threshold));  // relu
+    }
+  } else if (act_type == 3) {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b);
+      SIMD_ST_F32(c + index, SIMD_CLAMP_F32(dst, down_threshold, up_threshold));  // relu6
+    }
+  } else {
+    for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+      SIMD_F32 vec_b = SIMD_LD_F32(b + index);
+      SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b);
+      SIMD_ST_F32(c + index, dst);  // no_act
+    }
+  }
+
+  return index;
+}
+
+#if defined(MS_SIMD_AVX512) || defined(MS_SIMD_AVX)
+static inline int64_t GemmIsNotPackOptimizeCore@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, int k, float *dst) {
+  SIMD_F32 dst1 = SIMD_MOV_F32(0.0f);
+  for (int block_max_size = k - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 weight = SIMD_LD_F32(b + index);
+    SIMD_F32 a1 = SIMD_LD_F32(a + index);
+    dst1 = SIMD_FMADD_F32(weight, a1, dst1);
+  }
+  *dst += SIMD_REDUCE_ADD_F32(dst1);
+  return index;
+}
+#endif
+
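+// inc_flag protocol for MatVecMulNoPackCore below (as used by
+// MatVecMulNoPackFp32 in matmul_fp32.c): bit 0 set means this is the first
+// depth chunk, so the accumulator starts from bias (or 0) instead of c[];
+// bit 1 set means this is the last chunk, so the activation is applied before
+// the store. Depth is processed in C1500NUM-sized chunks.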
+static inline int64_t MatVecMulNoPackCore@SIMD_INSTRUCTION@(int64_t oc_index, const float *a, const float *b, float *c, const float *bias,
+                                                            int act_type, int64_t depth, int64_t oc, int64_t col, int64_t inc_flag) {
+  for (int64_t oc_max_size = oc - BLOCK_NUM; oc_index <= oc_max_size; oc_index += BLOCK_NUM) {
+    SIMD_F32 out = (inc_flag & 0x1) == 0 ? SIMD_LD_F32(c + oc_index) : (bias == NULL ? SIMD_MOV_F32(0.0f) : SIMD_LD_F32(bias + oc_index));
+    for (int64_t k = 0; k < depth; ++k) {
+      SIMD_F32 left = SIMD_MOV_F32(a[k]);
+      SIMD_F32 right = SIMD_LD_F32(b + oc_index + k * col);
+      out = SIMD_FMADD_F32(left, right, out);
+    }
+    if ((inc_flag & 0x2) != 0 && act_type != 0) {
+      out = SIMD_MAX_F32(out, SIMD_MOV_F32(0.0f));
+      if (act_type == 0x3) {
+        out = SIMD_MIN_F32(out, SIMD_MOV_F32(6.0f));
+      }
+    }
+    SIMD_ST_F32(c + oc_index, out);
+  }
+  return oc_index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c
new file mode 100644
index 00000000..2846b098
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.c
@@ -0,0 +1,187 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32/mul_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/mul_fp32_simd.h"
+#include "nnacl_c/errorcode.h"
+
+int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
+                 ArithmeticParameter *param) {
+  TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param);
+  return ElementMul(tile_in0, tile_in1, out, size);
+}
+
+int ElementMul(const float *in0, const float *in1, float *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementMul, index, in0, in1, out, size);
+  for (; index < size; index++) {
+    out[index] = in0[index] * in1[index];
+  }
+  return NNACL_OK;
+}
+
+int ElementMulRelu(const float *in0, const float *in1, float *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementMulRelu, index, in0, in1, out, size);
+  for (; index < size; index++) {
+    float res = in0[index] * in1[index];
+    out[index] = res > 0 ? res : 0;
+  }
+  return NNACL_OK;
+}
+
+int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementMulRelu6, index, in0, in1, out, size);
+  for (; index < size; index++) {
+    out[index] = MSMIN(MSMAX(in0[index] * in1[index], 0), 6);
+  }
+  return NNACL_OK;
+}
+
+int ElementMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementMulInt, index, in0, in1, out, size);
+  for (; index < size; index++) {
+    out[index] = in0[index] * in1[index];
+  }
+  return NNACL_OK;
+}
+
+int ElementMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  int index = 0;
+
+  SIMD_RUN_NO_SCALAR(ElementMulReluInt, index, in0, in1, out, size);
+  for (; index < size; index++) {
+    int res = in0[index] * in1[index];
+    out[index] = res > 0 ?
res : 0; + } + return NNACL_OK; +} + +int ElementMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulRelu6Int, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptMul(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] * in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] * in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulReluNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] * in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulReluNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] * in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6Num0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6Num1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementOptMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] * in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] * in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulReluIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] * in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulReluIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] * in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6IntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6IntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[0], 0), 6); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h new file mode 100644 index 00000000..41941d62 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_MUL_H_ +#define MINDSPORE_NNACL_FP32_MUL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ElementMul(const float *in0, const float *in1, float *out, int size); +int ElementMulRelu(const float *in0, const float *in1, float *out, int size); +int ElementMulRelu6(const float *in0, const float *in1, float *out, int size); +int ElementMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptMul(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32_simd.h.in new file mode 100644 index 00000000..33bc1a37 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/mul_fp32_simd.h.in @@ -0,0 +1,211 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_FP32_MUL_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_FP32_MUL_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int ElementMul@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MUL_F32(vin0, vin1);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1), 0.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin0 = SIMD_LD_F32(in0 + index);
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1), 0.0f), 6.0f);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0, vin1);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulReluInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1), 0.0f);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementMulRelu6Int@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index);
+    SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index);
+    SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1), 0.0f), 6.0f);
+    SIMD_ST_EPI32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptMulNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]);
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_F32 vin1 = SIMD_LD_F32(in1 + index);
+    SIMD_F32 vout = SIMD_MUL_F32(vin0_opt_, vin1);
+    SIMD_ST_F32(out + index, vout);
+  }
+  return index;
+}
+
+static inline int ElementOptMulNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) {
+  SIMD_F32
vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MUL_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0_opt_, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0_opt_, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0_opt_, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0_opt_, vin1), 0.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, 
const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1_opt_), 0.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6IntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0_opt_, vin1), 0.0f), 6.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6IntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.c new file mode 100644 index 00000000..5fdea60c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.c @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/nllloss_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +int NLLLoss(const float *logits, const int32_t *labels, const float *weight, float *loss, float *total_weight, + const NLLLossStruct *nllloss, const ReductionType reduction_type) { + NNACL_CHECK_NULL_RETURN_ERR(logits); + NNACL_CHECK_NULL_RETURN_ERR(labels); + NNACL_CHECK_NULL_RETURN_ERR(weight); + NNACL_CHECK_NULL_RETURN_ERR(loss); + NNACL_CHECK_NULL_RETURN_ERR(total_weight); + + float total_loss = 0.0; + float tmp_total_weight = 0.0; + for (int i = 0; i < nllloss->batch_; i++) { + int index = i * nllloss->class_num_ + labels[i]; + float n_weight = weight[labels[i]]; + float n_loss = -logits[index] * n_weight; + tmp_total_weight += n_weight; + total_loss += n_loss; + if (reduction_type == Reduction_None) { + loss[i] = n_loss; + } + } + + *total_weight = tmp_total_weight; + if (reduction_type == Reduction_Sum) { + *loss = total_loss; + } else if (reduction_type == Reduction_Mean) { + *loss = total_loss / tmp_total_weight; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.h new file mode 100644 index 00000000..e50de641 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/nllloss_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_NLLLOSS_FP32_H_ +#define NNACL_FP32_NLLLOSS_FP32_H_ + +#include "nnacl_c/kernel/nllloss.h" + +#ifdef __cplusplus +extern "C" { +#endif +int NLLLoss(const float *logits, const int32_t *labels, const float *weight, float *loss, float *total_weight, + const NLLLossStruct *parameter, const ReductionType reduction_type); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_NLLLOSS_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.c new file mode 100644 index 00000000..63ed8ab3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.c @@ -0,0 +1,205 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/non_max_suppression_fp32.h" +#include +#include +#include "nnacl_c/tensor_c_utils.h" + +typedef struct { + int32_t batch_index_; + int32_t class_index_; + int32_t box_index_; +} NMSIndex; + +typedef struct { + float score_; + int index_; + float y1_; // y1 x1 y2 x2 ascending order + float y2_; + float x1_; + float x2_; + float area_; +} NMSBox; + +void CreateNMBox(NMSBox *box, float score, int index, int cpb, float y_a, float x_a, float y_b, float x_b) { + box->score_ = score; + box->index_ = index; + if (0 == cpb) { + box->y1_ = NNACL_MIN(y_a, y_b); + box->y2_ = NNACL_MAX(y_a, y_b); + box->x1_ = NNACL_MIN(x_a, x_b); + box->x2_ = NNACL_MAX(x_a, x_b); + } else { + // x_center, y_center, width, height + float half_wid = x_b / 2; + box->x1_ = x_a - half_wid; + box->x2_ = x_a + half_wid; + float half_height = y_b / 2; + box->y1_ = y_a - half_height; + box->y2_ = y_a + half_height; + } + box->area_ = (box->y2_ - box->y1_) * (box->x2_ - box->x1_); +} + +bool CheckIoUSuppressed(const NMSBox *box, const NMSBox *cand, float iou_threshold) { + float intersec_x1 = NNACL_MAX(cand->x1_, box->x1_); + float intersec_x2 = NNACL_MIN(cand->x2_, box->x2_); + float intersec_y1 = NNACL_MAX(cand->y1_, box->y1_); + float intersec_y2 = NNACL_MIN(cand->y2_, box->y2_); + const float intersec_area = NNACL_MAX(intersec_x2 - intersec_x1, 0.0f) * NNACL_MAX(intersec_y2 - intersec_y1, 0.0f); + if (intersec_area <= 0.0f) { + return false; + } + const float intersec_over_union = intersec_area / (cand->area_ + box->area_ - intersec_area); + return intersec_over_union > iou_threshold; +} + +bool LessThan(NMSBox *box1, NMSBox *box2) { + return box1->score_ < box2->score_ || + (fabs(box1->score_ - box2->score_) < FLT_EPSILON && box1->index_ > box2->index_); +} + +void SortCandidates(ExecEnv *env, NMSBox **sorted, NMSBox *origin, int size) { + bool *sorted_index = (bool *)env->Alloc(env->allocator_, size * sizeof(bool)); + NNACL_CHECK_NULL_RETURN_VOID(sorted); + memset(sorted_index, 0, size * sizeof(bool)); + + NMSBox min_box; + min_box.score_ = FLT_MIN; + min_box.index_ = 0; + + for (int i = 0; i < size; i++) { + int max_index = 0; + NMSBox *box = &min_box; + for (int j = 0; j < size; j++) { + if (sorted_index[j]) { + continue; + } + if (LessThan(box, &origin[j])) { + max_index = j; + } + } + sorted[i] = &origin[max_index]; + sorted_index[max_index] = true; + } + + env->Free(env->allocator_, sorted); + return; +} + +int NonMaxSuppressionSelecte(NonMaxSuppressionStruct *nm_suppression, bool simple_out, int *score_dims) { + const float *box_data = (float *)nm_suppression->base_.in_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(box_data); + const float *scores_data = (float *)nm_suppression->base_.in_[Index1]->data_; // batch, class, num + NNACL_CHECK_NULL_RETURN_ERR(scores_data); + ExecEnv *env = nm_suppression->base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + int batch_num = score_dims[Index0]; + int class_num = score_dims[Index1]; + int box_num = score_dims[Index2]; + + int selected_box_per_class_max_size = NNACL_MIN((int)box_num, nm_suppression->max_output_per_class_); + NNACL_CHECK_MALLOC_SIZE(selected_box_per_class_max_size); + NMSBox *selected_box_per_class = env->Alloc(env->allocator_, selected_box_per_class_max_size * sizeof(NMSBox)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(selected_box_per_class); + memset(selected_box_per_class, 0, selected_box_per_class_max_size * sizeof(NMSBox)); + NMSBox *above_score_candidates = env->Alloc(env->allocator_, box_num * sizeof(NMSBox)); + 
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(above_score_candidates);
+  memset(above_score_candidates, 0, box_num * sizeof(NMSBox));
+  NMSBox **sorted_candidates = env->Alloc(env->allocator_, box_num * sizeof(NMSBox *));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sorted_candidates);
+  memset(sorted_candidates, 0, box_num * sizeof(NMSBox *));
+  int selected_index_max_size = box_num;
+  int selected_index_size = 0;
+  NMSIndex *selected_index = env->Alloc(env->allocator_, selected_index_max_size * sizeof(NMSIndex));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(selected_index);
+
+  for (int i = 0; i < batch_num; ++i) {
+    int batch_offset = i * class_num * box_num;
+    for (int j = 0; j < class_num; ++j) {
+      // per batch per class filter
+      const float *per_class_scores = scores_data + batch_offset + j * box_num;
+      const float *box = box_data + i * box_num * Num4;
+      int above_score_candidates_size = 0;
+      for (int k = 0; k < box_num; ++k) {
+        if (per_class_scores[k] > nm_suppression->score_threshold_) {
+          CreateNMBox(&above_score_candidates[above_score_candidates_size++], per_class_scores[k], k,
+                      nm_suppression->center_point_box_, box[Index0], box[Index1], box[Index2], box[Index3]);
+        }
+        box += Num4;
+      }
+
+      int sorted_candidates_size = above_score_candidates_size;
+      SortCandidates(env, sorted_candidates, above_score_candidates, above_score_candidates_size);
+
+      // take candidates in descending score order, capping the output per class
+      int selected_box_per_class_size = 0;
+      int cand_index = 0;
+      while (cand_index < sorted_candidates_size && selected_box_per_class_size < selected_box_per_class_max_size &&
+             selected_index_size < selected_index_max_size) {
+        NMSBox *cand = sorted_candidates[cand_index++];
+        bool selected = true;
+        for (int k = 0; k < selected_box_per_class_size; k++) {
+          if (CheckIoUSuppressed(&selected_box_per_class[k], cand, nm_suppression->iou_threshold_)) {
+            selected = false;
+            break;
+          }
+        }
+
+        if (selected) {
+          selected_box_per_class[selected_box_per_class_size++] = *cand;
+          selected_index[selected_index_size].batch_index_ = i;
+          selected_index[selected_index_size].class_index_ = j;
+          selected_index[selected_index_size].box_index_ = cand->index_;
+          selected_index_size++;
+        }
+      }
+    }
+  }
+
+  TensorC *output = nm_suppression->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+  if (!simple_out) {
+    const int output_last_dim = Num3;
+    int output_shape[] = {selected_index_size, output_last_dim};
+    output->shape_size_ = Num2;
+    memcpy(output->shape_, output_shape, output->shape_size_ * sizeof(int));
+    int output_size = selected_index_size * sizeof(NMSIndex);
+    if (output_size != NNACLGetSize(output)) {
+      return NNACL_NON_MAX_SUPPRESSION_OUTPUT_SIZE_UNMATCH;
+    }
+    int *out_data = (int *)env->Alloc(env->allocator_, output_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_data);
+    output->data_ = out_data;
+    memcpy(out_data, selected_index, output_size);
+  } else {
+    int output_shape[] = {selected_index_size};
+    output->shape_size_ = Num1;
+    memcpy(output->shape_, output_shape, output->shape_size_ * sizeof(int));
+    int *out_data = (int *)env->Alloc(env->allocator_, NNACLGetSize(output));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_data);
+    output->data_ = out_data;
+    for (int i = 0; i < selected_index_size; i++) {
+      out_data[i] = selected_index[i].box_index_;
+    }
+  }
+
+  env->Free(env->allocator_, selected_box_per_class);
+  env->Free(env->allocator_, above_score_candidates);
+  env->Free(env->allocator_, sorted_candidates);
+  env->Free(env->allocator_, selected_index);
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.h
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.h new file mode 100644 index 00000000..61e37d70 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/non_max_suppression_fp32.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_ +#define NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/non_max_suppression.h" + +int NonMaxSuppressionSelecte(NonMaxSuppressionStruct *nm_suppression, bool simple_out, int *score_dims); + +#endif // NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.c new file mode 100644 index 00000000..d7d15361 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.c @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/one_hot_fp32.h" +#include "nnacl_c/errorcode.h" + +int OneHotToFp32(const int32_t *indices, float on_value, float off_value, float *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num) { + if (indices == NULL || one_hot_param == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { // output layout: outer_size * depth * inner_size + const int32_t *indices_ptr = indices + i * inner_size; + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = *(indices_ptr++); + if (one_hot_param->support_neg_index_ && index < 0) { + index += depth; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.h new file mode 100644 index 00000000..e61b429c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/one_hot_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_ONE_HOT_FP32_H_ +#define NNACL_FP32_ONE_HOT_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/one_hot_parameter.h" +#include "nnacl_c/kernel/one_hot.h" + +#ifdef __cplusplus +extern "C" { +#endif +int OneHotToFp32(const int32_t *indices, float on_value, float off_value, float *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_ONE_HOT_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.c new file mode 100644 index 00000000..a0133ccd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.c @@ -0,0 +1,69 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/cast_gather_reduce_fp32_simd.h" + +int64_t Fp32CastGatherReduceInt64Fusion(float *output_data, const int64_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + int index = 0; + SIMD_RUN_NO_SCALAR(Fp32CastGatherReduceInt64Fusion, index, output_data, input_indices, input_data, inner_size, + input_data_inner_size, outer_start, outer_end); + + if (index < input_data_inner_size) { + for (int i = outer_start; i < outer_end; i++) { + float *result = output_data + i * input_data_inner_size + index; + int64_t indice0 = input_indices[i * inner_size]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] = input_data[indice0 * input_data_inner_size + k]; + } + for (int j = 1; j < inner_size; j++) { + int64_t indice = input_indices[i * inner_size + j]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] += input_data[indice * input_data_inner_size + k]; + } + } + } + } + return NNACL_OK; +} + +int64_t Fp32CastGatherReduceInt32Fusion(float *output_data, const int32_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + int index = 0; + SIMD_RUN_NO_SCALAR(Fp32CastGatherReduceInt32Fusion, index, output_data, input_indices, input_data, inner_size, + input_data_inner_size, outer_start, outer_end); + + if (index < input_data_inner_size) { + for (int i = outer_start; i < outer_end; i++) { + float *result = output_data + i * input_data_inner_size + index; + int32_t indice0 = input_indices[i * inner_size]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] = input_data[indice0 * input_data_inner_size + k]; + } + for (int j = 1; j < inner_size; j++) { + int32_t indice = input_indices[i * inner_size + j]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] += input_data[indice * input_data_inner_size + k]; + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h new file mode 100644 index 00000000..30c2efc1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32CastGatherReduceInt64Fusion(float *output_data, const int64_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end); + +int64_t Fp32CastGatherReduceInt32Fusion(float *output_data, const int32_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in new file mode 100644 index 00000000..1bfd16fa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in @@ -0,0 +1,65 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+// clang-format off
+#ifndef MINDSPORE_NNACL_CAST_GATHER_REDUCE_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_CAST_GATHER_REDUCE_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+static inline int Fp32CastGatherReduceInt64Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const int64_t *input_indices, const float *input_data,
+                                                                    int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start,
+                                                                    int32_t outer_end) {
+  for (int block_max_size = input_data_inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    for (int i = outer_start; i < outer_end; i++) {
+      SIMD_F32 result = SIMD_SET0_F32;
+      for (int j = 0; j < inner_size; j++) {
+        int64_t indice = input_indices[i * inner_size + j];
+        result = SIMD_ADD_F32(result, SIMD_LD_F32(input_data + indice * input_data_inner_size + index));
+      }
+      SIMD_ST_F32(output_data + i * input_data_inner_size + index, result);
+    }
+  }
+  return index;
+}
+
+
+static inline int Fp32CastGatherReduceInt32Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const int32_t *input_indices, const float *input_data,
+                                                                    int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start,
+                                                                    int32_t outer_end) {
+  for (int block_max_size = input_data_inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    for (int i = outer_start; i < outer_end; i++) {
+      SIMD_F32 result = SIMD_SET0_F32;
+      for (int j = 0; j < inner_size; j++) {
+        int32_t indice = input_indices[i * inner_size + j];
+        result = SIMD_ADD_F32(result, SIMD_LD_F32(input_data + indice * input_data_inner_size + index));
+      }
+      SIMD_ST_F32(output_data + i * input_data_inner_size + index, result);
+    }
+  }
+  return index;
+}
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.c
new file mode 100644
index 00000000..b4b42698
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.c
@@ -0,0 +1,124 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp32/online_fusion/reduce_concat_fp32.h" +#include +#include "nnacl_c/reduce_concat_fp32_simd.h" +#include "nnacl_c/errorcode.h" + +int64_t Fp32ReduceSumConcatAxisSizeAVX512Fusion(float *output_data, float **input_datas, + const int64_t *reduce_axis_size, int64_t input_nums, int64_t batch, + int64_t batch_tile_size, int64_t inner_tile, int64_t thread_num, + int64_t task_id) { + int64_t single_thread_tile = DOWN_DIV(batch, thread_num); + int64_t less_tile = batch - thread_num * single_thread_tile; + + int64_t batch_start = task_id * single_thread_tile; + if (task_id < less_tile) { + single_thread_tile += 1; + batch_start += task_id; + } else { + batch_start += less_tile; + } + int64_t batch_end = batch_start + single_thread_tile; + int64_t last_inner_size = batch_tile_size - (input_nums - 1) * inner_tile; + + int res = NNACL_OK; + if (inner_tile == C16NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize16Fusion, res, result, + input_datas[j] + i * C16NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C16NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C32NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize32Fusion, res, result, + input_datas[j] + i * C32NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C32NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C64NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize64Fusion, res, result, + input_datas[j] + i * C64NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C64NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C128NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize128Fusion, res, result, + input_datas[j] + i * C128NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C128NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } + return res; +} + +int64_t Fp32ReduceSumConcatFusion(float *output_data, float **input_datas, const int64_t *reduce_axis_size, + int64_t input_nums, int64_t batch, int64_t batch_tile_size, int64_t inner_tile, + int64_t thread_num, int64_t task_id) { + AVX512_HARDWARE_SELF_AWARENESS_BEGIN; + if (inner_tile == C16NUM || inner_tile == C32NUM || inner_tile == C64NUM || inner_tile == C128NUM) { + return Fp32ReduceSumConcatAxisSizeAVX512Fusion(output_data, input_datas, reduce_axis_size, input_nums, batch, + batch_tile_size, inner_tile, thread_num, task_id); + } + AVX512_HARDWARE_SELF_AWARENESS_END; + + int64_t single_thread_tile = DOWN_DIV(batch, thread_num); + int64_t less_tile = batch - thread_num * single_thread_tile; + + int64_t batch_start = task_id * single_thread_tile; + if (task_id < less_tile) { + 
batch_start += task_id; + single_thread_tile += 1; + } else { + batch_start += less_tile; + } + int64_t batch_end = batch_start + single_thread_tile; + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + const float *input_data_ptr = input_datas[j] + i * inner_tile * reduce_axis_size[j]; + + for (int k = 0; k < inner_tile; k++) { + result[k] = input_data_ptr[k]; + for (int l = 1; l < reduce_axis_size[j]; l++) { + result[k] += input_data_ptr[l * inner_tile + k]; + } + } + result += inner_tile; + } + + int64_t inner_size2 = batch_tile_size - (input_nums - 1) * inner_tile; + const float *input_data_ptr = input_datas[input_nums - 1] + i * inner_size2; + (void)memcpy(result, input_data_ptr, inner_size2 * sizeof(float)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.h new file mode 100644 index 00000000..c0e586b4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32ReduceSumConcatFusion(float *output_data, float **input_datas, const int64_t *reduce_axis_size, + int64_t input_nums, int64_t batch, int64_t batch_tile_size, int64_t inner_tile, + int64_t thread_num, int64_t task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32_simd.h.in new file mode 100644 index 00000000..0fd1e612 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/reduce_concat_fp32_simd.h.in @@ -0,0 +1,115 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+// clang-format off
+#ifndef MINDSPORE_NNACL_REDUCE_CONCAT_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+#define MINDSPORE_NNACL_REDUCE_CONCAT_FP32_SIMD_@SIMD_INSTRUCTION@_H_
+
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+@SIMD_INSTRUCTION_BEGIN@
+
+#ifdef MS_SIMD_AVX512
+static inline int Fp32ReduceSumConcatAxisSize16Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (1 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data));
+  }
+  SIMD_ST_F32(output_data, zmm00);
+  return index;
+}
+
+static inline int Fp32ReduceSumConcatAxisSize32Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM);
+  SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (2 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM));
+    zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM));
+  }
+
+  SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00);
+  SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01);
+
+  return index;
+}
+
+static inline int Fp32ReduceSumConcatAxisSize64Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM);
+  SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM);
+  SIMD_F32 zmm02 = SIMD_LD_F32(input_data + 2 * BLOCK_NUM);
+  SIMD_F32 zmm03 = SIMD_LD_F32(input_data + 3 * BLOCK_NUM);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (4 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM));
+    zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM));
+    zmm02 = SIMD_ADD_F32(zmm02, SIMD_LD_F32(input_data + 2 * BLOCK_NUM));
+    zmm03 = SIMD_ADD_F32(zmm03, SIMD_LD_F32(input_data + 3 * BLOCK_NUM));
+  }
+
+  SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00);
+  SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01);
+  SIMD_ST_F32(output_data + 2 * BLOCK_NUM, zmm02);
+  SIMD_ST_F32(output_data + 3 * BLOCK_NUM, zmm03);
+
+  return index;
+}
+
+static inline int Fp32ReduceSumConcatAxisSize128Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) {
+  // eight running block sums, one per zmm register
+  SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM);
+  SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM);
+  SIMD_F32 zmm02 = SIMD_LD_F32(input_data + 2 * BLOCK_NUM);
+  SIMD_F32 zmm03 = SIMD_LD_F32(input_data + 3 * BLOCK_NUM);
+  SIMD_F32 zmm04 = SIMD_LD_F32(input_data + 4 * BLOCK_NUM);
+  SIMD_F32 zmm05 = SIMD_LD_F32(input_data + 5 * BLOCK_NUM);
+  SIMD_F32 zmm06 = SIMD_LD_F32(input_data + 6 * BLOCK_NUM);
+  SIMD_F32 zmm07 = SIMD_LD_F32(input_data + 7 * BLOCK_NUM);
+  for (int l = 1; l < reduce_axis_size; l++) {
+    input_data += (8 * BLOCK_NUM);
+    zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM));
+    zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM));
+    zmm02 = SIMD_ADD_F32(zmm02, SIMD_LD_F32(input_data + 2 * BLOCK_NUM));
+    zmm03 = SIMD_ADD_F32(zmm03, SIMD_LD_F32(input_data + 3 * BLOCK_NUM));
+    zmm04 = SIMD_ADD_F32(zmm04, SIMD_LD_F32(input_data + 4 * BLOCK_NUM));
+    zmm05 = SIMD_ADD_F32(zmm05, SIMD_LD_F32(input_data + 5 * BLOCK_NUM));
+    zmm06 = SIMD_ADD_F32(zmm06, SIMD_LD_F32(input_data + 6 * BLOCK_NUM));
+    zmm07 = SIMD_ADD_F32(zmm07, SIMD_LD_F32(input_data + 7 * BLOCK_NUM));
+  }
+
+  SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00);
+  SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01);
+  SIMD_ST_F32(output_data + 2 * BLOCK_NUM, zmm02);
+  SIMD_ST_F32(output_data + 3 * BLOCK_NUM, zmm03);
+  SIMD_ST_F32(output_data + 4 * BLOCK_NUM, zmm04);
+  SIMD_ST_F32(output_data + 5 * BLOCK_NUM, zmm05);
+  SIMD_ST_F32(output_data + 6 * BLOCK_NUM, zmm06);
+  SIMD_ST_F32(output_data + 7 * BLOCK_NUM, zmm07);
+
+  return index;
+}
+
+#endif
+
+
+@SIMD_INSTRUCTION_END@
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.c
new file mode 100644
index 00000000..adcc2c8c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.c
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h"
+#include <stdint.h>
+#include "nnacl_c/reduce_fp32_simd.h"
+#include "nnacl_c/errorcode.h"
+
+int64_t Fp32SplitReduceSumConcatFusion(const float *src, float *dst, int64_t inner_size, int64_t mid_size,
+                                       int32_t *mid_split, int64_t mid_len, int64_t out_size) {
+  const float *cur_src = src;
+  float *cur_dst = dst;
+  for (int64_t i = 0; i < out_size; i++) {
+    for (int64_t j = 0; j < mid_len; j++) {
+      int k = 0;
+      SIMD_RUN_NO_SCALAR(ReduceSum, k, cur_src, cur_dst, inner_size, mid_split[j]);
+      for (; k < inner_size; k++) {
+        float result = cur_src[k];
+        for (int64_t l = 1; l < mid_split[j]; l++) {
+          result += cur_src[inner_size * l + k];
+        }
+        cur_dst[k] = result;
+      }
+      cur_src += (inner_size * mid_split[j]);
+      cur_dst += inner_size;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h
new file mode 100644
index 00000000..5faecf5a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/online_fusion/split_reduce_concat_fp32.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32SplitReduceSumConcatFusion(const float *src, float *dst, int64_t inner_size, int64_t mid_size, + int32_t *mid_split, int64_t mid_len, int64_t out_size); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c new file mode 100644 index 00000000..9dccfd84 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c @@ -0,0 +1,2078 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { + PackNCHWToNHWCFp32(src, dst, 1, plane, channel, 0, 0); +} + +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float)); + } + } +} + +void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) { + if (channel <= C4NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float)); + return; + } + int tmp = DOWN_DIV(channel, C4NUM); + int c_res = channel - tmp * C4NUM; + int c4_block = tmp * plane * C4NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C4NUM; + int c = 0; + for (; c <= channel - C4NUM; c += C4NUM) { +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_FLOAT32X4 src_data = MS_LDQ_F32(src + src_kernel_offset + c); + MS_STQ_F32(dst + dst_kernel_offset + c * plane, src_data); +#else + for (int k1 = 0; k1 < C4NUM; ++k1) { + (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1]; + } +#endif + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c4_block + k * c_res + c - tmp * C4NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) { + if (channel <= C8NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float)); + return; + } + int tmp = DOWN_DIV(channel, C8NUM); + int c_res = channel - tmp * C8NUM; + int c8_block = tmp * plane * C8NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = 
batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C8NUM; + int c = 0; + for (; c <= channel - C8NUM; c += C8NUM) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 src_data = MS_LD256_F32(src + src_kernel_offset + c); + MS_ST256_F32(dst + dst_kernel_offset + c * plane, src_data); +#else + for (int k1 = 0; k1 < C8NUM; ++k1) { + (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1]; + } +#endif + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; ++r) { + for (int c = 0; c < col; ++c) { + dst_ptr[c * row + r] = src_ptr[r * col + c]; + } + } +} + +void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + if (row_end > row_start) { + src_ptr += row_start * col; + dst_ptr += row_start * col; + memcpy(dst_ptr, src_ptr, (row_end - row_start) * col * (int)(sizeof(float))); + } +} + +void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c]; + } + for (; c < UP_ROUND(col, C4NUM); c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = 0; + } + } + return; +} + +void RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c]; + } + for (; c < UP_ROUND(col, C6NUM); c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = 0; + } + } + return; +} + +void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c]; + } + for (; c < UP_ROUND(col, C8NUM); c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = 0; + } + } + return; +} + +void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd12 = c / C12NUM; + int cm12 = c % C12NUM; + dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c]; + } + for (; c < UP_ROUND(col, C12NUM); c++) { + int cd12 = c / C12NUM; + int cm12 = c % C12NUM; + dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = 0; + } + } + return; +} + +void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 
src[c]; + } + for (; c < UP_ROUND(col, C16NUM); c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0; + } + } + return; +} + +void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) { + // Not exactly aligned to 32, but aligned to 24 or 16 or 8 if 32 is not met. + int row_block_num = UP_DIV(row, C8NUM); + int row_block = C4NUM; + for (int i = 0; i < row_block_num; i += row_block) { + row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4 + int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM); + dst_ptr += col_start * row_block * C8NUM; + for (int oc = col_start; oc < col_end; ++oc) { + memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float)); + dst_ptr += row_block * C8NUM; + } + dst_ptr += (col - col_end) * row_block * C8NUM; + } +} + +void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) { + // Not exactly aligned to 64, but aligned to 48 or 32 or 16 if 64 is not met. + int row_block_num = UP_DIV(row, C16NUM); + int row_block = C4NUM; + for (int i = 0; i < row_block_num; i += row_block) { + row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4 + int row_remainder = MSMIN(row_block * C16NUM, row - i * C16NUM); + dst_ptr += col_start * row_block * C16NUM; + for (int oc = col_start; oc < col_end; ++oc) { + memcpy(dst_ptr, src_ptr + oc * row + i * C16NUM, row_remainder * sizeof(float)); + dst_ptr += row_block * C16NUM; + } + dst_ptr += (col - col_end) * row_block * C16NUM; + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s}, [x10], %[stride]\n" + "ld1 {v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s}, [x10], %[stride]\n" + "ld1 {v3.4s}, [x10], %[stride]\n" + + "ld1 {v4.4s}, [x10], %[stride]\n" + "ld1 {v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s}, [x10], %[stride]\n" + "ld1 {v7.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v0.4s, v1.4s\n" + "zip2 v13.4s, v0.4s, v1.4s\n" + "zip1 v14.4s, v2.4s, v3.4s\n" + "zip2 v15.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x10], %[stride]\n" + "ld1 {v9.4s}, [x10], %[stride]\n" + "ld1 {v10.4s}, [x10], %[stride]\n" + "ld1 {v11.4s}, [x10], %[stride]\n" + + "zip1 v16.4s, v4.4s, v5.4s\n" + "zip2 v17.4s, v4.4s, v5.4s\n" + "zip1 v18.4s, v6.4s, v7.4s\n" + "zip2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v23.2d, v12.2d, v14.2d\n" + "trn1 v26.2d, v13.2d, v15.2d\n" + "trn2 v29.2d, v13.2d, v15.2d\n" + + "trn1 v21.2d, v16.2d, v18.2d\n" + "trn2 v24.2d, v16.2d, v18.2d\n" + "trn1 v27.2d, v17.2d, v19.2d\n" + "trn2 v30.2d, v17.2d, v19.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", +
"v31"); + return; +} +#endif +#ifdef ENABLE_ARM32 +void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q10}, [r10], %[stride]\n" + "vld1.32 {q13}, [r10], %[stride]\n" + + "vtrn.32 d0, d6\n" + "vtrn.32 d1, d7\n" + "vtrn.32 d20, d26\n" + "vtrn.32 d21, d27\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q8}, [r10], %[stride]\n" + "vld1.32 {q11}, [r10], %[stride]\n" + "vld1.32 {q14}, [r10], %[stride]\n" + + "vswp d1, d20\n" + "vswp d7, d26\n" + + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q9}, [r10], %[stride]\n" + "vld1.32 {q12}, [r10], %[stride]\n" + "vld1.32 {q15}, [r10], %[stride]\n" + + "vtrn.32 d2, d16\n" + "vtrn.32 d3, d17\n" + "vtrn.32 d22, d28\n" + "vtrn.32 d23, d29\n" + + "vswp d3, d22\n" + "vswp d17, d28\n" + + "vtrn.32 d4, d18\n" + "vtrn.32 d5, d19\n" + "vtrn.32 d24, d30\n" + "vtrn.32 d25, d31\n" + + "vswp d5, d24\n" + "vswp d19, d30\n" + + "vst1.32 {q0, q1}, [r12]!\n" + "vst1.32 {q2, q3}, [r12]!\n" + "vst1.32 {q8, q9}, [r12]!\n" + "vst1.32 {q10, q11}, [r12]!\n" + "vst1.32 {q12, q13}, [r12]!\n" + "vst1.32 {q14, q15}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + return; +} +#endif +void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int ri = (row_start / C12NUM * C12NUM); + float *dst_r = dst_ptr + ri * col; + const float *src_r = src_ptr + ri * col; + for (; ri < (row_end / C12NUM * C12NUM); ri += C12NUM) { + int ci = 0; + for (; ci < (col / C4NUM * C4NUM); ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM64 + RowMajor2Col12Major_arm64(src_c, dst_c, col); +#elif ENABLE_ARM32 + RowMajor2Col12Major_arm32(src_c, dst_c, col); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); + __m128 src12H = _mm_unpackhi_ps(src1, src2); + __m128 src34L = _mm_unpacklo_ps(src3, src4); + __m128 src34H = _mm_unpackhi_ps(src3, src4); + + __m128 dst0 = _mm_movelh_ps(src12L, src34L); + __m128 dst3 = _mm_movehl_ps(src34L, src12L); + __m128 dst6 = _mm_movelh_ps(src12H, src34H); + __m128 dst9 = _mm_movehl_ps(src34H, src12H); + + __m128 src5 = _mm_loadu_ps(src_c); + __m128 src6 = _mm_loadu_ps(src_c + col); + __m128 src7 = _mm_loadu_ps(src_c + 2 * col); + __m128 src8 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src56L = _mm_unpacklo_ps(src5, src6); + __m128 src56H = _mm_unpackhi_ps(src5, src6); + __m128 src78L = _mm_unpacklo_ps(src7, src8); + __m128 src78H = _mm_unpackhi_ps(src7, src8); + __m128 dst1 = _mm_movelh_ps(src56L, src78L); + __m128 dst4 = _mm_movehl_ps(src78L, src56L); + __m128 dst7 = _mm_movelh_ps(src56H, src78H); + __m128 dst10 = _mm_movehl_ps(src78H, src56H); + + __m128 src9 = _mm_loadu_ps(src_c); + __m128 src10 = _mm_loadu_ps(src_c + col); + __m128 src11 = _mm_loadu_ps(src_c + 2 * col); + __m128 src12 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src910L = _mm_unpacklo_ps(src9, src10); + __m128 src910H = _mm_unpackhi_ps(src9, src10); + __m128 src1112L = 
_mm_unpacklo_ps(src11, src12); + __m128 src1112H = _mm_unpackhi_ps(src11, src12); + __m128 dst2 = _mm_movelh_ps(src910L, src1112L); + __m128 dst5 = _mm_movehl_ps(src1112L, src910L); + __m128 dst8 = _mm_movelh_ps(src910H, src1112H); + __m128 dst11 = _mm_movehl_ps(src1112H, src910H); + + _mm_storeu_ps(dst_c, dst0); + _mm_storeu_ps(dst_c + 4, dst1); + _mm_storeu_ps(dst_c + 8, dst2); + _mm_storeu_ps(dst_c + 12, dst3); + _mm_storeu_ps(dst_c + 16, dst4); + _mm_storeu_ps(dst_c + 20, dst5); + _mm_storeu_ps(dst_c + 24, dst6); + _mm_storeu_ps(dst_c + 28, dst7); + _mm_storeu_ps(dst_c + 32, dst8); + _mm_storeu_ps(dst_c + 36, dst9); + _mm_storeu_ps(dst_c + 40, dst10); + _mm_storeu_ps(dst_c + 44, dst11); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + for (int i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + if (row_end == row) { + for (; ri < row_end; ri++, dst_r++, src_r += col) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = src_r[i]; + } + } + for (; ri < UP_ROUND(row, C12NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + } + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n" + "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + + "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n" + "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n" + "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n" + "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v0.2d, v8.2d, v10.2d\n" + "trn2 v1.2d, v8.2d, v10.2d\n" + "trn1 v2.2d, v9.2d, v11.2d\n" + "trn2 v3.2d, v9.2d, v11.2d\n" + + "zip1 v24.4s, v16.4s, v18.4s\n" + "zip2 v25.4s, v16.4s, v18.4s\n" + "zip1 v26.4s, v20.4s, v22.4s\n" + "zip2 v27.4s, v20.4s, v22.4s\n" + + "trn1 v4.2d, v12.2d, v14.2d\n" + "trn2 v5.2d, v12.2d, v14.2d\n" + "trn1 v6.2d, v13.2d, v15.2d\n" + "trn2 v7.2d, v13.2d, v15.2d\n" + + "zip1 v28.4s, v17.4s, v19.4s\n" + "zip2 v29.4s, v17.4s, v19.4s\n" + "zip1 v30.4s, v21.4s, v23.4s\n" + "zip2 v31.4s, v21.4s, v23.4s\n" + + "trn1 v16.2d, v24.2d, v26.2d\n" + "trn2 v17.2d, v24.2d, v26.2d\n" + "trn1 v18.2d, v25.2d, v27.2d\n" + "trn2 v19.2d, v25.2d, v27.2d\n" + + "trn1 v20.2d, v28.2d, v30.2d\n" + "trn2 v21.2d, v28.2d, v30.2d\n" + "trn1 v22.2d, v29.2d, v31.2d\n" + "trn2 v23.2d, v29.2d, v31.2d\n" + + "st1 {v0.4s}, [x11], #16\n" + "st1 {v16.4s}, [x11], #16\n" + "st1 {v1.4s}, [x11], #16\n" + "st1 {v17.4s}, [x11], #16\n" + "st1 {v2.4s}, [x11], #16\n" + "st1 {v18.4s}, [x11], #16\n" + "st1 {v3.4s}, [x11], #16\n" + "st1 {v19.4s}, [x11], #16\n" + "st1 {v4.4s}, [x11], #16\n" + "st1 {v20.4s}, [x11], #16\n" + "st1 {v5.4s}, [x11], #16\n" + "st1 {v21.4s}, [x11], #16\n" + "st1 {v6.4s}, [x11], #16\n" + "st1 {v22.4s}, [x11], #16\n" + "st1 {v7.4s}, [x11], #16\n" + "st1 {v23.4s}, [x11], #16\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] 
"r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); + return; +} +#endif +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r11, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r11]!\n" + "vst1.32 {q2, q3}, [r11]!\n" + "vst1.32 {q4, q5}, [r11]!\n" + "vst1.32 {q6, q7}, [r11]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + return; +} +#else +void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r7, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r7]!\n" + "vst1.32 {q2, q3}, [r7]!\n" + "vst1.32 {q4, q5}, [r7]!\n" + "vst1.32 {q6, q7}, [r7]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + return; +} +#endif +#endif +void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row8 = row_end / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + int col_skip = col / C8NUM * C8NUM; + int skip_size = C8NUM; +#else + int col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; +#endif + int ri = (row_start / C8NUM * C8NUM); + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row8; ri += C8NUM) { + int ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8Major_arm64(src_c, dst_c, col); +#elif ENABLE_ARM32 + RowMajor2Col8Major_arm32(src_c, dst_c, col); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); // x5 + __m128 src12H = _mm_unpackhi_ps(src1, src2); // x1 + __m128 src34L = _mm_unpacklo_ps(src3, src4); // x + __m128 src34H = _mm_unpackhi_ps(src3, src4); + _mm_storeu_ps(dst_c, 
_mm_movelh_ps(src12L, src34L)); + _mm_storeu_ps(dst_c + C8NUM, _mm_movehl_ps(src34L, src12L)); + _mm_storeu_ps(dst_c + C16NUM, _mm_movelh_ps(src12H, src34H)); + _mm_storeu_ps(dst_c + C24NUM, _mm_movehl_ps(src34H, src12H)); + + __m128 src5 = _mm_loadu_ps(src_c); + __m128 src6 = _mm_loadu_ps(src_c + col); + __m128 src7 = _mm_loadu_ps(src_c + 2 * col); + __m128 src8 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src56L = _mm_unpacklo_ps(src5, src6); + __m128 src56H = _mm_unpackhi_ps(src5, src6); + __m128 src78L = _mm_unpacklo_ps(src7, src8); + __m128 src78H = _mm_unpackhi_ps(src7, src8); + _mm_storeu_ps(dst_c + C4NUM, _mm_movelh_ps(src56L, src78L)); + _mm_storeu_ps(dst_c + C12NUM, _mm_movehl_ps(src78L, src56L)); + _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H)); + _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H)); +#else + for (int tr = 0; tr < 8; tr++) { + for (int tc = 0; tc < 4; tc++) { + dst_c[tc * 8 + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++, src_r += col, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + } + + for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = 0; + } + } + } +} + +void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row16 = row_end / C16NUM * C16NUM; + int ri = row_start / C16NUM * C16NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row16; ri += C16NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_AVX + Transpose8X8Fp32Avx(src_c, dst_c, col, C16NUM); + Transpose8X8Fp32Avx(src_c + C8NUM * col, dst_c + C8NUM, col, C16NUM); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (int i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + int total_row = UP_ROUND(row, C16NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + // Not exactly aligned to 32, but aligned to 24 or 16 or 8 If 32 is not met. 
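+ // Rows are handled in tiles of up to four 8-row blocks (32 rows; the last
+ // tile shrinks to 24, 16 or 8). Within a tile each source column is stored
+ // contiguously with stride cur_block * C8NUM, so one column of the whole
+ // tile can later be read back as a single linear run.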
+#ifdef ENABLE_AVX + int col8 = col / C8NUM * C8NUM; +#endif + int all_block_num = UP_DIV(row, C8NUM); + int cur_block = C4NUM; + row_start = UP_DIV(row_start, C8NUM); + row_end = UP_DIV(row_end, C8NUM); + for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) { + cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4 + int dst_stride = cur_block * C8NUM; + int row_num = MSMIN(dst_stride, row - i * C8NUM); +#ifdef ENABLE_AVX + int row8_num = row_num / C8NUM * C8NUM; +#endif + const float *src = src_ptr + i * C8NUM * col; + float *dst = dst_ptr + i * C8NUM * col; + int r = 0; +#ifdef ENABLE_AVX + for (; r < row8_num; r += C8NUM) { + int c = 0; + for (; c < col8; c += C8NUM) { + Transpose8X8Fp32Avx(src + r * col + c, dst + c * dst_stride + r, col, dst_stride); + } + for (; c < col; ++c) { + for (int k = 0; k < C8NUM; ++k) { + dst[c * dst_stride + r + k] = src[r * col + c + k * col]; + } + } + } +#endif + for (; r < row_num; r++) { + for (int c = 0; c < col; ++c) { + dst[c * dst_stride + r] = src[r * col + c]; + } + } + } +} + +void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + // Not exactly aligned to 64, but aligned to 48 or 32 or 16 If 64 is not met. + int all_block_num = UP_DIV(row, C16NUM); + int cur_block = C4NUM; + row_start = UP_DIV(row_start, C16NUM); + row_end = UP_DIV(row_end, C16NUM); + for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) { + cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4 + int dst_stride = cur_block * C16NUM; + int row_num = MSMIN(dst_stride, row - i * C16NUM); + const float *src = src_ptr + i * C16NUM * col; + float *dst = dst_ptr + i * C16NUM * col; + int r = 0; + for (; r < row_num; r++) { + for (int c = 0; c < col; ++c) { + dst[c * dst_stride + r] = src[r * col + c]; + } + } + } +} + +void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row6 = row_end / C6NUM * C6NUM; + int ri = row_start / C6NUM * C6NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row6; ri += C6NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + +#ifdef ENABLE_AVX + __m256 src0 = _mm256_loadu_ps(src_c); + __m256 src1 = _mm256_loadu_ps(src_c + col); + __m256 src2 = _mm256_loadu_ps(src_c + 2 * col); + __m256 src3 = _mm256_loadu_ps(src_c + 3 * col); + __m256 src4 = _mm256_loadu_ps(src_c + 4 * col); + __m256 src5 = _mm256_loadu_ps(src_c + 5 * col); + __m256 trans0 = _mm256_unpacklo_ps(src0, src1); + __m256 trans1 = _mm256_unpacklo_ps(src2, src3); + __m256 trans2 = _mm256_unpacklo_ps(src4, src5); + __m256 trans3 = _mm256_unpackhi_ps(src0, src1); + __m256 trans4 = _mm256_unpackhi_ps(src2, src3); + __m256 trans5 = _mm256_unpackhi_ps(src4, src5); + __m128 lo0 = _mm256_castps256_ps128(trans0); + __m128 lo1 = _mm256_castps256_ps128(trans1); + __m128 lo2 = _mm256_castps256_ps128(trans2); + __m128 lo3 = _mm256_castps256_ps128(trans3); + __m128 lo4 = _mm256_castps256_ps128(trans4); + __m128 lo5 = _mm256_castps256_ps128(trans5); + __m128 hi0 = _mm256_extractf128_ps(trans0, 1); + __m128 hi1 = _mm256_extractf128_ps(trans1, 1); + __m128 hi2 = _mm256_extractf128_ps(trans2, 1); + __m128 hi3 = _mm256_extractf128_ps(trans3, 1); + __m128 hi4 = _mm256_extractf128_ps(trans4, 1); + __m128 hi5 = _mm256_extractf128_ps(trans5, 1); + __m128 res0 = _mm_shuffle_ps(lo0, 
lo1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); + _mm_storeu_ps(dst_c, res0); + _mm_storeu_ps(dst_c + 4, res1); + _mm_storeu_ps(dst_c + 8, res2); + _mm_storeu_ps(dst_c + 12, res3); + _mm_storeu_ps(dst_c + 16, res4); + _mm_storeu_ps(dst_c + 20, res5); + _mm_storeu_ps(dst_c + 24, res6); + _mm_storeu_ps(dst_c + 28, res7); + _mm_storeu_ps(dst_c + 32, res8); + _mm_storeu_ps(dst_c + 36, res9); + _mm_storeu_ps(dst_c + 40, res10); + _mm_storeu_ps(dst_c + 44, res11); +#else + for (int tr = 0; tr < C6NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C6NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + for (int i = 0; i < C6NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C6NUM * col; + dst_r += C6NUM * col; + } + + if (row_end == row) { + for (; ri < row_end; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + + int totalRow = UP_ROUND(row, C6NUM); + for (; ri < totalRow; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row4 = row_end / C4NUM * C4NUM; + int ri = row_start / C4NUM * C4NUM; + int col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row4; ri += C4NUM) { + int ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + +#ifdef ENABLE_ARM32 + int stride = col * 4; + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + + "vtrn.32 d0, d2\n" + "vtrn.32 d1, d3\n" + "vtrn.32 d4, d6\n" + "vtrn.32 d5, d7\n" + + "vswp d1, d4\n" + "vswp d3, d6\n" + + "vst1.32 {q0}, [r12]!\n" + "vst1.32 {q1}, [r12]!\n" + "vst1.32 {q2}, [r12]!\n" + "vst1.32 {q3}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3"); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); + __m128 src12H = _mm_unpackhi_ps(src1, src2); + __m128 src34L = _mm_unpacklo_ps(src3, src4); + __m128 src34H = _mm_unpackhi_ps(src3, src4); + + __m128 dst0 = _mm_movelh_ps(src12L, src34L); + __m128 dst1 = _mm_movehl_ps(src34L, src12L); + __m128 dst2 = _mm_movelh_ps(src12H, src34H); + __m128 dst3 = _mm_movehl_ps(src34H, src12H); 
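+ // dst0..dst3 now hold columns 0..3 of the 4x4 tile: unpacklo/unpackhi
+ // interleave rows 1-2 and 3-4, and movelh/movehl recombine the 64-bit
+ // halves (the same shuffle network as the classic _MM_TRANSPOSE4_PS idiom).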
+ + _mm_storeu_ps(dst_c, dst0); + _mm_storeu_ps(dst_c + 4, dst1); + _mm_storeu_ps(dst_c + 8, dst2); + _mm_storeu_ps(dst_c + 12, dst3); +#else + for (size_t tr = 0; tr < C4NUM; tr++) { + for (size_t tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C4NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + + int total_row = UP_ROUND(row, C4NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C4NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2ColMajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2RowMajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row4MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row6MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row8MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row12MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row16MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row) { + RowMajor2Row32MajorParallel(src_ptr, dst_ptr, col, row, 0, col); +} +void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row) { + RowMajor2Row64MajorParallel(src_ptr, dst_ptr, col, row, 0, col); +} +void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col12MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col8MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col16MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col32MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col64MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col6MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col4MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_minus = c4 - 1; + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * 
C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int j = 0; j < c4_minus; ++j) { + int src_ic_offset = src_kernel_offset + j * C4NUM; + int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM; +#ifdef ENABLE_ARM + vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset)); +#else + for (int i = 0; i < C4NUM; ++i) { + ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; + } +#endif + } + int tmp_c = c4_minus * C4NUM; + int tmp_c_offset = tmp_c * plane; + int res_c = channel - tmp_c; + if (res_c > channel) { + return; + } + for (int l = 0; l < res_c; ++l) { + int src_ic_offset = src_kernel_offset + tmp_c + l; + int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + RowMajor2Col4Major((const float *)src + src_offset, (float *)dst + dst_offset, channel, plane); + } +} + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int oc_block = UP_DIV(channel, C4NUM); + int oc_block_channel = oc_block * C4NUM; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int dst_batch_offset = b * oc_block_channel * plane; + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile) { + int oc_block = UP_DIV(channel, oc_tile); + int oc_block_channel = oc_block * oc_tile; + int ic_remainder_ = channel % oc_tile; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int dst_batch_offset = b * oc_block_channel * plane; + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void PackNHWCToNXHWCXFp32H1W1(int output_channel, int oc_block_num, int input_channel, float *tmp_weight, + const float *src, int oc_block_unit, Transpose8X8Fp32Func transpose_func) { + int oc_block8 = DOWN_DIV(output_channel, C8NUM); + int oc = 0; + int oc_block = 0; + int ic8 = DOWN_ROUND(input_channel, C8NUM); + int oc_remainder_step = 0; + if (oc_block8 != oc_block_num) { + oc_block8 = oc_block8 / oc_block_unit * oc_block_unit; + oc_remainder_step = (oc_block_num - oc_block8) * C8NUM; + } + for (; oc < oc_block8; oc += (oc_block / C8NUM)) { + oc_block = 
MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8 + for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) { + int ic = 0; + for (; ic < ic8; ic += C8NUM) { + transpose_func(src + ic, tmp_weight + ic * oc_block + oc_tmp, input_channel, oc_block); + } + for (; ic < input_channel; ++ic) { + for (int j = 0; j < C8NUM; ++j) { + tmp_weight[ic * oc_block + oc_tmp + j] = src[ic + input_channel * j]; + } + } + src += C8NUM * input_channel; + } + tmp_weight += oc_block * input_channel; + } + oc = output_channel - oc_block8 * C8NUM; + for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) { + for (int ic = 0; ic < input_channel; ++ic) { + tmp_weight[oc_remainder + oc_remainder_step * ic] = src[ic + oc_remainder * input_channel]; + } + } +} + +// PackNHWCToNXHWCXFp32 is SWPackNHWCToNXHWCXFp32 asm optimize +void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Arm64; + int oc_block_unit = C2NUM; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Avx; + int oc_block_unit = C4NUM; +#endif + // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8 + // output_channel: batch + int plane = kernel_w * kernel_h; + if (plane == 1) { // conv 1x1 weight pack + PackNHWCToNXHWCXFp32H1W1(output_channel, oc_block_num, input_channel, tmp_weight, src, oc_block_unit, + transpose_func); + return; + } + + int ic8 = DOWN_ROUND(input_channel, C8NUM); + int oc_block8 = DOWN_DIV(output_channel, C8NUM); + int oc_block = 0; + int oc = 0; + int oc_remainder_step = 0; + if (oc_block8 != oc_block_num) { + oc_block8 = oc_block8 / oc_block_unit * oc_block_unit; + oc_remainder_step = (oc_block_num - oc_block8) * C8NUM; + } + for (; oc < oc_block8; oc += (oc_block / C8NUM)) { + oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8 + for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) { + for (int hw = 0; hw < plane; ++hw) { + int ic = 0; + for (; ic < ic8; ic += C8NUM) { + transpose_func(src + hw * input_channel + ic, + tmp_weight + hw * oc_block * input_channel + ic * oc_block + oc_tmp, input_channel * plane, + oc_block); + } + for (; ic < input_channel; ++ic) { + for (int j = 0; j < C8NUM; ++j) { + tmp_weight[ic * oc_block + oc_tmp + j + hw * oc_block * input_channel] = + src[ic + input_channel * j * plane + hw * input_channel]; + } + } + } + src += C8NUM * plane * input_channel; + } + tmp_weight += oc_block * input_channel * plane; + } + oc = output_channel - oc_block8 * C8NUM; + for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) { + for (int hw = 0; hw < plane; ++hw) { + for (int ic = 0; ic < input_channel; ++ic) { + tmp_weight[oc_remainder + oc_remainder_step * ic + hw * input_channel * oc_remainder_step] = + src[ic + (oc_remainder * plane + hw) * input_channel]; + } + } + } +} + +#ifdef ENABLE_DEBUG +void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src) { + // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8 + int oc_block = 0; + for (int i = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C4NUM, oc_block_num - i); // max_tile = 4 + int index = i * C8NUM * kernel_h * kernel_w * input_channel; + int oc_remainder = MSMIN(C8NUM * oc_block, output_channel - i * C8NUM); + for (int h = 0; h < kernel_h; ++h) { + for (int w = 
0; w < kernel_w; ++w) { + int w_index = (h * kernel_w + w) * input_channel + index; + for (int ic = 0; ic < input_channel; ++ic) { + int ic_index = ic + w_index; + for (int oc = 0; oc < oc_remainder; ++oc) { + int oc_index = oc * kernel_w * kernel_h * input_channel + ic_index; + tmp_weight[oc] = src[oc_index]; + } + tmp_weight += oc_block * C8NUM; + } + } + } + } +} +#endif +#endif + +void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int c8_channel = c8 * C8NUM; + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + for (int j = channel; j < c8_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num) { + int c_algin = UP_DIV(channel, cx_num); + int ic_remainder_ = channel % cx_num; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c_algin * cx_num * plane; + for (int i = 0; i < plane; i++) { + memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, + (float *)src + batch_offset + i * c_algin * cx_num, channel * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel) { + const float *fp32_src = (const float *)src; + float *fp32_dst = (float *)dst; + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C4NUM; + size_t c_mod = c % C4NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset = c_div * plane * C4NUM + p * C4NUM + c_mod; + int dst_offset = c * plane + p; + fp32_dst[dst_offset] = fp32_src[src_offset]; + } + } +} + +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_ROUND(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4; + int dst_offset = b * plane * channel; + UnPackC4Uint((const float *)src + src_offset, (float *)dst + dst_offset, plane, channel); + } +} + +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { 
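+ // NC4HW4 keeps channels in blocks of four, each block laid out plane-major;
+ // for every plane position the loop below copies each full block of four
+ // floats back into NHWC order, then handles the channel remainder element
+ // by element.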
+ int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_STQ_F32((float *)dst + dst_c_offset, MS_LDQ_F32((float *)src + src_c_offset)); +#else + ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0]; + ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1]; + ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2]; + ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3]; +#endif + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_ROUND(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8; + int dst_offset = b * plane * channel; + + const float *fp32_src = (const float *)src + src_offset; + float *fp32_dst = (float *)dst + dst_offset; + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C8NUM; + size_t c_mod = c % C8NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset_c = c_div * plane * C8NUM + p * C8NUM + c_mod; + int dst_offset_c = c * plane + p; + fp32_dst[dst_offset_c] = fp32_src[src_offset_c]; + } + } + } +} + +void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int c8_minus = c8 - 1; + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c8 * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C8NUM; + for (int j = 0; j < c8_minus; ++j) { + int src_ic_offset = src_kernel_offset + j * C8NUM; + int dst_ic_offset = dst_kernel_offset + j * plane * C8NUM; + for (int i = 0; i < C8NUM; ++i) { + ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; + } + } + int tmp_c = c8_minus * C8NUM; + int tmp_c_offset = tmp_c * plane; + int res_c = channel - tmp_c; + if (res_c > channel) { + return; + } + for (int l = 0; l < res_c; ++l) { + int src_ic_offset = src_kernel_offset + tmp_c + l; + int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8 * C8NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C8NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c8 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C8NUM; + int dst_c_offset = dst_kernel_offset + c * C8NUM; + + ((float *)dst + dst_c_offset)[Index0] = ((float *)src + src_c_offset)[Index0]; + ((float 
*)dst + dst_c_offset)[Index1] = ((float *)src + src_c_offset)[Index1]; + ((float *)dst + dst_c_offset)[Index2] = ((float *)src + src_c_offset)[Index2]; + ((float *)dst + dst_c_offset)[Index3] = ((float *)src + src_c_offset)[Index3]; + } + // res part + int res_c = channel - (c8 - 1) * C8NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane, + const int channel) { + int down_channel_8 = DOWN_ROUND(channel, C8NUM); + int up_channel_16 = UP_ROUND(channel, C16NUM); + size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float); + size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float); + size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float); + size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float); + size_t src_p_offset = C8NUM * sizeof(float); + for (size_t b = 0; b < (size_t)(batch); ++b) { + const char *src_batch = (char *)(src) + b * src_batch_offset; + char *dst_bacth = (char *)(dst) + b * dst_batch_offset; + memcpy(dst_bacth, src_batch, aligned_channel_size); + src_batch += aligned_channel_size; + dst_bacth += aligned_channel_size; + for (int p = 0; p < plane; ++p) { + memcpy(dst_bacth + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size); + } + } +} + +void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int channel_up8 = UP_ROUND(channel, C8NUM); + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + int c = 0; + for (; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((float *)dst)[dst_index] = ((float *)src)[src_index]; + } + for (; c < channel_up8; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((float *)dst)[dst_index] = 0; + } + } + } +} + +void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) { + // pack weight NHWC to C24HWN24 (Priority 24)=>C16HWN16 (Not satisfied 24)=>C8HWN8 (Not satisfied 16) +#ifdef ENABLE_AVX + int oc_block_num = UP_DIV(channel, C8NUM); + int plane16 = plane / C16NUM * C16NUM; + for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C3NUM, oc_block_num - i); + int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM); + int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM; + int p = 0; + for (; p < plane16; p += C16NUM) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + int oc = 0; + int stride = oc_block * C8NUM * batch; + for (; oc < oc_remainder_c8; oc += C8NUM) { + const float *cur_src = src + index_batch + oc; + float *cur_dst = dst + oc; + MS_LOAD256X16_F32(r, cur_src, channel); + STORE256X16_F32(cur_dst, stride, r); + } + for (; oc < oc_remainder; ++oc) { + for (int k = 0; k < C16NUM; ++k) { + dst[oc + stride * k] = src[index_batch + oc + channel * k]; + } + } + for (; oc < C8NUM; ++oc) { + for (int k = 0; k < C16NUM; ++k) { + dst[oc 
+ stride * k] = 0; + } + } + dst += oc_block * C8NUM; + } + dst += (C16NUM - 1) * oc_block * C8NUM * batch; + } + for (; p < plane; ++p) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + int oc = 0; + for (; oc < oc_remainder; ++oc) { + dst[oc] = src[index_batch + oc]; + } + for (; oc < C8NUM; ++oc) { + dst[oc] = 0; + } + dst += oc_block * C8NUM; + } + } + } +#else + int oc_block = 0; + int oc_block_num = UP_DIV(channel, C8NUM); + for (int i = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C3NUM, oc_block_num - i); // max_tile = 4 + int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM); + for (int p = 0; p < plane; ++p) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + for (int oc = 0; oc < oc_remainder; ++oc) { + dst[oc] = src[index_batch + oc]; + } + dst += oc_block * C8NUM; + } + } + } +#endif +} + +void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int c = 0; c < c4; c++) { + int dst_off_c = c * C4NUM * height * width; + for (int i = 0; i < C4NUM; i++) { + int src_off_c = (c * C4NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int c = 0; c < c8; c++) { + int dst_off_c = c * C8NUM * height * width; + for (int i = 0; i < C8NUM; i++) { + int src_off_c = (c * C8NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel, int task_id, + int thread_count) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64; +#elif defined(ENABLE_ARM32) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx; +#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse; +#endif + int hw8 = plane / C8NUM; + int task_start = 0; + int task_end = plane; + if (thread_count > 0) { + int offset_hw = UP_DIV(hw8, thread_count) * C8NUM; + task_start = offset_hw * task_id; + int count = plane - task_start; + if (count <= 0) { + return; + } + task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw); + hw8 = task_start + ((task_end - task_start) >= offset_hw ? 
offset_hw : 0); + } else { + hw8 *= C8NUM; + } + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const float *src_batch = (const float *)src + n * batch; + float *dst_batch = (float *)dst + n * batch; + int hw = task_start; + for (; hw < hw8; hw += C8NUM) { + int c = 0; + for (; c < c8; c += C8NUM) { + const float *src_ptr = src_batch + hw * channel + c; + float *dst_ptr = dst_batch + c * plane + hw; +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32) + Transpose8X8Fp32Func_(src_ptr, dst_ptr, channel, plane); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const float *src_ptr = src_batch + hw * channel + c; + float *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < C8NUM; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < task_end; hw++) { + const float *src_ptr = src_batch + hw * channel; + float *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } +} + +/* +|<---------------- plane --------------->| ++---------------------------+------------+ --- +| | | | | ↑ +|8x8-blocks| ... |8x8-blocks| right | | +| | | | | | ++ - - - - -+ + - - - - -+ | | +| ... ... ... | top | channel ++ - - - - -+ + - - - - -| | | +| | | | tails | | +|8x8-blocks| ... |8x8-blocks| | | ++---------------------------+------------+ | +| |right bottom| | +| left bottom tails | tails | ↓ ++---------------------------+------------+ --- +*/ +void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64; +#elif defined(ENABLE_ARM32) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx; +#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse; +#endif + int m_pad = UP_DIV(channel, C8NUM); + int n_pad = UP_DIV(plane, C8NUM); + int m_blk = channel / C8NUM; + int n_blk = plane / C8NUM; + int b_stride = plane * channel; + int b = 0, m = 0, n = 0; + // To make writes to dst consecutive, (m,n): (0,0)->(1,0)->...->(0,1)->(1,1)->...
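+ // As the call pattern below suggests, offset_to_index_init/offset_to_index_step
+ // decode the flat task index into (m, n, b) block coordinates with m varying
+ // fastest, so each task handles one 8x8 block or one of the tail regions
+ // sketched above.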
+ offset_to_index_init(start, 6, &m, m_pad, &n, n_pad, &b, batches); + for (int task = start; task < end; task++) { + const float *src_batch = (const float *)src + b * b_stride; + float *dst_batch = (float *)dst + b * b_stride; + int m_start = m * C8NUM; + int n_start = n * C8NUM; + if (m < m_blk && n < n_blk) { + // process 8x8-blocks + const float *from = src_batch + m_start * plane + n_start; + float *to = dst_batch + n_start * channel + m_start; +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32) + Transpose8X8Fp32Func_(from, to, plane, channel); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + to[tc * channel + tr] = from[tr * plane + tc]; + } + } +#endif + } else { + // process right bottom tails + const float *from = src_batch; + float *to = dst_batch; + int i_start = m_start; + int i_end = channel; + int j_start = n_start; + int j_end = plane; + if (m >= m_blk && n < n_blk) { + // process left bottom tails + from = src_batch + n_start; + to = dst_batch + n_start * channel; + j_start = 0; + j_end = C8NUM; + } else if (m < m_blk && n >= n_blk) { + // process right top tails + from = src_batch + m_start * plane; + to = dst_batch + m_start; + i_start = 0; + i_end = C8NUM; + } + transpose_tail(from, to, j_start, j_end, i_start, i_end, channel, plane); + } + offset_to_index_step(6, &m, m_pad, &n, n_pad, &b, batches); + } +} + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) { + PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count); +} + +#ifdef ENABLE_ARM64 +inline void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + size_t srcStride = src_stride * sizeof(float); + size_t dstStride = dst_stride * sizeof(float); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "trn1 v16.2d, v8.2d, v10.2d\n" + "trn2 v18.2d, v8.2d, v10.2d\n" + "trn1 v20.2d, v9.2d, v11.2d\n" + "trn2 v22.2d, v9.2d, v11.2d\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "trn1 v24.2d, v12.2d, v14.2d\n" + "trn2 v26.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v30.2d, v13.2d, v15.2d\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v17.2d, v8.2d, v10.2d\n" + "trn2 v19.2d, v8.2d, v10.2d\n" + "trn1 v21.2d, v9.2d, v11.2d\n" + "trn2 v23.2d, v9.2d, v11.2d\n" + + "trn1 v25.2d, v12.2d, v14.2d\n" + "trn2 v27.2d, v12.2d, v14.2d\n" + "trn1 v29.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n" + "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n" + "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n" + "st1 {v22.4s, v23.4s}, [x11], 
%[dstStride]\n" + "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n" + "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n" + "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n" + "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n" + + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +#ifdef ENABLE_ARM32 +inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + size_t srcStride = src_stride * sizeof(float); + size_t dstStride = dst_stride * sizeof(float); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.32 {q0, q1}, [r10], %[srcStride]\n" + "vld1.32 {q2, q3}, [r10], %[srcStride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vld1.32 {q4, q5}, [r10], %[srcStride]\n" + "vld1.32 {q6, q7}, [r10], %[srcStride]\n" + + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vld1.32 {q8, q9}, [r10], %[srcStride]\n" + "vld1.32 {q10, q11}, [r10], %[srcStride]\n" + + "vswp d1, d8\n" + "vswp d3, d10\n" + "vswp d5, d12\n" + "vswp d7, d14\n" + + "vtrn.32 d16, d20\n" + "vtrn.32 d17, d21\n" + "vtrn.32 d18, d22\n" + "vtrn.32 d19, d23\n" + + "vld1.32 {q12, q13}, [r10], %[srcStride]\n" + "vld1.32 {q14, q15}, [r10], %[srcStride]\n" + + "vtrn.32 d24, d28\n" + "vtrn.32 d25, d29\n" + "vtrn.32 d26, d30\n" + "vtrn.32 d27, d31\n" + + "vswp d17, d24\n" + "vswp d19, d26\n" + "vswp d21, d28\n" + "vswp d23, d30\n" + + "add r10, r12, #16\n" + "vst1.32 {q0}, [r12], %[dstStride]\n" + "vst1.32 {q8}, [r10], %[dstStride]\n" + "vst1.32 {q2}, [r12], %[dstStride]\n" + "vst1.32 {q10}, [r10], %[dstStride]\n" + "vst1.32 {q4}, [r12], %[dstStride]\n" + "vst1.32 {q12}, [r10], %[dstStride]\n" + "vst1.32 {q6}, [r12], %[dstStride]\n" + "vst1.32 {q14}, [r10], %[dstStride]\n" + "vst1.32 {q1}, [r12], %[dstStride]\n" + "vst1.32 {q9}, [r10], %[dstStride]\n" + "vst1.32 {q3}, [r12], %[dstStride]\n" + "vst1.32 {q11}, [r10], %[dstStride]\n" + "vst1.32 {q5}, [r12], %[dstStride]\n" + "vst1.32 {q13}, [r10], %[dstStride]\n" + "vst1.32 {q7}, [r12], %[dstStride]\n" + "vst1.32 {q15}, [r10], %[dstStride]\n" + + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} +#endif + +#ifdef ENABLE_AVX +/* + Using _mm256_insertf128_ps at the beginning, instead of using _mm256_permute2f128_ps at the end. + On the whole, 4 vinsertf128 and 4 vperm2f128 are used less than before. 
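+ Each 256-bit register is assembled up front from the 128-bit halves of rows
+ i and i + 4, so the unpack/shuffle/blend network that follows only permutes
+ within 128-bit lanes and no cross-lane pass is needed at the end.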
+*/ +inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + const float *src1 = src_ptr + 0 * src_stride; + const float *src2 = src_ptr + 1 * src_stride; + const float *src3 = src_ptr + 2 * src_stride; + const float *src4 = src_ptr + 3 * src_stride; + const float *src5 = src_ptr + 4 * src_stride; + const float *src6 = src_ptr + 5 * src_stride; + const float *src7 = src_ptr + 6 * src_stride; + const float *src8 = src_ptr + 7 * src_stride; + + __m256 r1, r2, r3, r4, r5, r6, r7, r8; + __m256 t1, t2, t3, t4, t5, t6, t7, t8; + // _mm256_castps128_ps256 is only for compilation and generates no instructions, thus it has zero latency. + r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 0)), _mm_loadu_ps(src5 + 0), 1); + r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 0)), _mm_loadu_ps(src6 + 0), 1); + r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 0)), _mm_loadu_ps(src7 + 0), 1); + r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 0)), _mm_loadu_ps(src8 + 0), 1); + r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 4)), _mm_loadu_ps(src5 + 4), 1); + r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 4)), _mm_loadu_ps(src6 + 4), 1); + r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 4)), _mm_loadu_ps(src7 + 4), 1); + r8 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 4)), _mm_loadu_ps(src8 + 4), 1); + + t1 = _mm256_unpacklo_ps(r1, r2); + t2 = _mm256_unpackhi_ps(r1, r2); + t3 = _mm256_unpacklo_ps(r3, r4); + t4 = _mm256_unpackhi_ps(r3, r4); + t5 = _mm256_unpacklo_ps(r5, r6); + t6 = _mm256_unpackhi_ps(r5, r6); + t7 = _mm256_unpacklo_ps(r7, r8); + t8 = _mm256_unpackhi_ps(r7, r8); + + __m256 v; + v = _mm256_shuffle_ps(t1, t3, 0x4E); + r1 = _mm256_blend_ps(t1, v, 0xCC); + r2 = _mm256_blend_ps(t3, v, 0x33); + + v = _mm256_shuffle_ps(t2, t4, 0x4E); + r3 = _mm256_blend_ps(t2, v, 0xCC); + r4 = _mm256_blend_ps(t4, v, 0x33); + + v = _mm256_shuffle_ps(t5, t7, 0x4E); + r5 = _mm256_blend_ps(t5, v, 0xCC); + r6 = _mm256_blend_ps(t7, v, 0x33); + + v = _mm256_shuffle_ps(t6, t8, 0x4E); + r7 = _mm256_blend_ps(t6, v, 0xCC); + r8 = _mm256_blend_ps(t8, v, 0x33); + + STORE256X8_F32(dst_ptr, dst_stride, r); +} +#endif + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + __m128 v0_ma = _mm_loadu_ps(src_ptr); + __m128 v1_ma = _mm_loadu_ps(src_ptr + src_stride); + __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride); + __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride); + + __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr, v8_ma); + _mm_storeu_ps(dst_ptr + dst_stride, v9_ma); + _mm_storeu_ps(dst_ptr + 2 * dst_stride, v10_ma); + _mm_storeu_ps(dst_ptr + 3 * dst_stride, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM); + v1_ma = _mm_loadu_ps(src_ptr + src_stride + C4NUM); + v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride + C4NUM); + v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride + C4NUM); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + 
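 // interleave rows 2 and 3 of this 4x4 sub-block, as in the block above
+ 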
v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM * dst_stride, v8_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride, v9_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride, v10_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride); + v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride); + v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride); + v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM, v8_ma); + _mm_storeu_ps(dst_ptr + dst_stride + C4NUM, v9_ma); + _mm_storeu_ps(dst_ptr + 2 * dst_stride + C4NUM, v10_ma); + _mm_storeu_ps(dst_ptr + 3 * dst_stride + C4NUM, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride + C4NUM); + v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride + C4NUM); + v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride + C4NUM); + v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride + C4NUM); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM * dst_stride + C4NUM, v8_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride + C4NUM, v9_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride + C4NUM, v10_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma); +} +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) { + // nchw to nc4hw4 with 1D F(2,3) + for (int i = 0; i < channel; i++) { + float *src_kernel = (float *)src + i * 9; + float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4; + for (int y = 0; y < 3; y++) { + float g0 = src_kernel[3 * y]; + float g1 = src_kernel[3 * y + 1]; + float g2 = src_kernel[3 * y + 2]; + + dst_kernel[16 * y] = g0; + dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2); + dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2); + dst_kernel[16 * y + 12] = g2; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.h new file mode 100644 index 00000000..4558e934 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_PACK_H_
+#define MINDSPORE_NNACL_FP32_PACK_H_
+
+#include <string.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+static inline void transpose_tail(const float *from, float *to, int j_start, int j_end, int i_start, int i_end,
+                                  int j_stride, int i_stride) {
+  // write consecutively: the inner loop stores to contiguous destination addresses
+  for (int j = j_start; j < j_end; j++) {
+    for (int i = i_start; i < i_end; i++) {
+      to[j * j_stride + i] = from[i * i_stride + j];
+    }
+  }
+}
+void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end);
+void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel);
+void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile);
+void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel);
+// Note: for single-threaded use, set task_id = 0 and thread_count = 0.
+void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count);
+void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count);
+void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num);
+void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
+void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel);
+void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel);
+void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel);
+void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel);
+
+void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end);
+void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end);
+void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end);
+void
RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end); +void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end); +void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); + +void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row); +void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row); +void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col); + +void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); +void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel); +void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel); + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel); +#endif + +// Transpose 8X8 Fp32 block data +typedef void (*Transpose8X8Fp32Func)(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#ifdef ENABLE_ARM64 +void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, 
int dst_stride);
+#endif
+#ifdef ENABLE_ARM32
+void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride);
+#endif
+#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
+void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
+                          float *tmp_weight, const float *src);
+#endif
+#ifdef ENABLE_AVX
+#ifdef ENABLE_DEBUG
+void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
+                            float *tmp_weight, const float *src);
+#endif
+void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride);
+#endif
+#if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
+void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride);
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_NNACL_FP32_PACK_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c
new file mode 100644
index 00000000..a58fd3c4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c
@@ -0,0 +1,292 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/fp32/pack_fp32_opt.h" +#include "nnacl_c/op_base.h" + +void RowMajor2Col12MajorOptCore(const float *src_c, float *dst_c, size_t stride, int64_t row, int64_t col) { + if (row <= 0 || col <= 0) { + return; + } + size_t stride_byte = stride * sizeof(float); + size_t stride_unit = stride * (C12NUM - 1); + int64_t r = 0; + for (; r <= row - C12NUM; r += C12NUM) { + int64_t c = 0; + for (; c <= col - C4NUM; c += C4NUM) { + asm volatile( + "mov x9, %[src_c]\n" + "mov x10, %[dst_c]\n" + + "ld1 {v0.4s}, [x9], %[stride_byte]\n" + "ld1 {v1.4s}, [x9], %[stride_byte]\n" + "ld1 {v2.4s}, [x9], %[stride_byte]\n" + "ld1 {v3.4s}, [x9], %[stride_byte]\n" + + "ld1 {v4.4s}, [x9], %[stride_byte]\n" + "ld1 {v5.4s}, [x9], %[stride_byte]\n" + "ld1 {v6.4s}, [x9], %[stride_byte]\n" + "ld1 {v7.4s}, [x9], %[stride_byte]\n" + + "zip1 v12.4s, v0.4s, v1.4s\n" + "zip2 v13.4s, v0.4s, v1.4s\n" + "zip1 v14.4s, v2.4s, v3.4s\n" + "zip2 v15.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x9], %[stride_byte]\n" + "ld1 {v9.4s}, [x9], %[stride_byte]\n" + "ld1 {v10.4s}, [x9], %[stride_byte]\n" + "ld1 {v11.4s}, [x9], %[stride_byte]\n" + + "zip1 v16.4s, v4.4s, v5.4s\n" + "zip2 v17.4s, v4.4s, v5.4s\n" + "zip1 v18.4s, v6.4s, v7.4s\n" + "zip2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v23.2d, v12.2d, v14.2d\n" + "trn1 v26.2d, v13.2d, v15.2d\n" + "trn2 v29.2d, v13.2d, v15.2d\n" + + "trn1 v21.2d, v16.2d, v18.2d\n" + "trn2 v24.2d, v16.2d, v18.2d\n" + "trn1 v27.2d, v17.2d, v19.2d\n" + "trn2 v30.2d, v17.2d, v19.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride_byte] "r"(stride_byte) + : "memory", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + dst_c += C48NUM; + src_c += C4NUM; + } + for (; c < col; ++c) { + for (int i = 0; i < C12NUM; ++i) { + dst_c[i] = src_c[i * stride]; + } + ++src_c; + dst_c += C12NUM; + } + src_c += stride_unit; + } + for (; r < row; ++r) { + for (int c = 0; c < col; ++c) { + dst_c[c * C12NUM] = src_c[c]; + } + src_c += stride; + ++dst_c; + } +} + +void RowMajor2Row12MajorOptCore(const float *src_c, float *dst_c, size_t stride, int64_t row, int64_t col) { + if (row <= 0 || col <= 0) { + return; + } + size_t stride_byte = stride * sizeof(float); + int64_t c = 0; + for (; c <= col - C12NUM; c += C12NUM) { + asm volatile( + "mov x9, %[src_c]\n" + "mov x10, %[dst_c]\n" + "mov x11, %[row]\n" + "1:\n" + "ld1 {v0.4s, v1.4s, v2.4s}, [x9], %[stride_byte]\n" + "st1 {v0.4s, v1.4s, v2.4s}, [x10], #48\n" + "subs x11, x11, #1\n" + "bgt 1b\n" + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride_byte] "r"(stride_byte), [row] "r"(row) + : "cc", "memory", "x9", "x10", "x11", "v0", "v1", "v2"); + dst_c += row * C12NUM; + src_c += C12NUM; + } + int64_t c_remain = col - c; + if (c_remain == 0) { + return; + } + for (int64_t r = 0; r < row; ++r) { + for (c = 0; c < c_remain; ++c) { + dst_c[r * C12NUM + c] = src_c[c]; + } + src_c += stride; + } +} + +void 
RowMajor2Col12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, + int64_t end) { + int64_t bundle_row = UP_DIV(row, C12NUM); + int64_t unit_num_per_batch = bundle_row * col; + if (unit_num_per_batch == 0) { + return; + } + int64_t start_batch = start / unit_num_per_batch; + int64_t end_batch = end / unit_num_per_batch; + int64_t start_remain = start % unit_num_per_batch; + int64_t end_remain = end % unit_num_per_batch; + if (col == 0) { + return; + } + int64_t start_row = start_remain / col; + int64_t end_row = end_remain / col; + int64_t start_col = start_remain % col; + int64_t end_col = end_remain % col; + const float *src = src_ptr + start_batch * row * col + start_row * C12NUM * col + start_col; + float *dst = dst_ptr + start * C12NUM; + int64_t row_num = C12NUM; + if (start_row * C12NUM + C12NUM > row) { + row_num -= (start_row * C12NUM + C12NUM - row); + } + if (start_batch == end_batch) { + if (start_row == end_row) { + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col - start_col); + return; + } + RowMajor2Col12MajorOptCore(src, dst, col, C12NUM, col - start_col); + src += C12NUM * col - start_col; + dst += (col - start_col) * C12NUM; + ++start_row; + if (start_row < end_row) { + row_num = (end_row - start_row) * C12NUM; + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += row_num * col; + } + row_num = C12NUM; + if (end_row * C12NUM + C12NUM > row) { + row_num -= (end_row * C12NUM + C12NUM - row); + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col); + return; + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col - start_col); + src += row_num * col - start_col; + dst += (col - start_col) * C12NUM; + row_num = row - start_row * C12NUM - C12NUM; + if (row_num > 0) { + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += UP_DIV(row_num, C12NUM) * C12NUM * col; + } + ++start_batch; + for (; start_batch < end_batch; ++start_batch) { + RowMajor2Col12MajorOptCore(src, dst, col, row, col); + src += row * col; + dst += bundle_row * C12NUM * col; + } + if (end_row > 0) { + row_num = end_row * C12NUM; + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += row_num * col; + } + row_num = C12NUM; + if (end_row * C12NUM + C12NUM > row) { + row_num -= (end_row * C12NUM + C12NUM - row); + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col); +} + +void RowMajor2Row12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, + int64_t end) { + int64_t bundle_col = UP_DIV(col, C12NUM); + int64_t unit_num_per_batch = bundle_col * row; + if (unit_num_per_batch == 0) { + return; + } + int64_t start_batch = start / unit_num_per_batch; + int64_t end_batch = end / unit_num_per_batch; + int64_t start_remain = start % unit_num_per_batch; + int64_t end_remain = end % unit_num_per_batch; + if (row == 0) { + return; + } + int64_t start_row = start_remain % row; + int64_t end_row = end_remain % row; + int64_t start_col = start_remain / row; + int64_t end_col = end_remain / row; + const float *src = src_ptr + start_batch * row * col + start_row * col + start_col * C12NUM; + float *dst = dst_ptr + start * C12NUM; + int64_t col_num = C12NUM; + if (start_col * C12NUM + C12NUM > col) { + col_num -= (start_col * C12NUM + C12NUM - col); + } + if (start_batch == end_batch) { + if (start_col == end_col) { + RowMajor2Row12MajorOptCore(src, dst, col, end_row - start_row, col_num); + return; + } + 
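   // same batch but the range spans several column strips: finish the starting
+    // strip, then copy the whole middle strips, then the trailing partial strip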
+    RowMajor2Row12MajorOptCore(src, dst, col, row - start_row, col_num);
+    src += C12NUM - start_row * col;
+    dst += (row - start_row) * C12NUM;
+    ++start_col;
+    if (start_col < end_col) {
+      col_num = (end_col - start_col) * C12NUM;
+      RowMajor2Row12MajorOptCore(src, dst, col, row, col_num);
+      src += col_num;
+      dst += row * col_num;
+    }
+    col_num = C12NUM;
+    if (end_col * C12NUM + C12NUM > col) {
+      col_num -= (end_col * C12NUM + C12NUM - col);
+    }
+    RowMajor2Row12MajorOptCore(src, dst, col, end_row, col_num);
+    return;
+  }
+  RowMajor2Row12MajorOptCore(src, dst, col, row - start_row, col_num);
+  src += col_num - start_row * col;
+  dst += (row - start_row) * C12NUM;
+  col_num = col - start_col * C12NUM - C12NUM;
+  if (col_num > 0) {
+    RowMajor2Row12MajorOptCore(src, dst, col, row, col_num);
+    src += col_num;
+    dst += UP_DIV(col_num, C12NUM) * C12NUM * row;
+  }
+  src += (row - 1) * col;
+  ++start_batch;
+  for (; start_batch < end_batch; ++start_batch) {
+    RowMajor2Row12MajorOptCore(src, dst, col, row, col);
+    src += row * col;
+    dst += bundle_col * C12NUM * row;
+  }
+  if (end_col > 0) {
+    col_num = end_col * C12NUM;
+    RowMajor2Row12MajorOptCore(src, dst, col, row, col_num);
+    src += col_num;
+    dst += row * col_num;
+  }
+  col_num = C12NUM;
+  if (end_col * C12NUM + C12NUM > col) {
+    col_num -= (end_col * C12NUM + C12NUM - col);
+  }
+  RowMajor2Row12MajorOptCore(src, dst, col, end_row, col_num);
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h
new file mode 100644
index 00000000..95a039cb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_PACK_FP32_V2_H
+#define MINDSPORE_NNACL_FP32_PACK_FP32_V2_H
+
+#ifdef ENABLE_ARM64
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*
+ * Packing variants whose work ranges support fine-grained multi-threading.
+ */
+
+void RowMajor2Col12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, int64_t end);
+
+void RowMajor2Row12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, int64_t end);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
+#endif  // MINDSPORE_NNACL_FP32_PACK_FP32_V2_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c
new file mode 100644
index 00000000..5115e4d1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.c
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/pad_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +void Pad(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *paddings, int tid, int thread_num) { + if (thread_num == 0) { + return; + } + int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + for (in[3] = 0; in[3] < input_shape[3]; in[3]++) { + out[3] = in[3] + paddings[6]; + for (in[4] = 0; in[4] < input_shape[4]; in[4]++) { + out[4] = in[4] + paddings[8]; + float *dst = output_data + Offset6d(output_shape, out) + paddings[10]; + const float *src = input_data + Offset6d(input_shape, in); + memcpy(dst, src, input_shape[5] * (int)(sizeof(float))); + } + } + } + } + } +} + +int TransOut2InputDimIndex(int out_dim_index, int left_pad, int in_dim, int offset) { + if (out_dim_index < left_pad) { + // left pad + const int index_sum = left_pad + offset - 1; + int in_index = MSMAX(index_sum - out_dim_index, offset); + return MSMIN(in_index, in_dim - 1); + } + out_dim_index -= left_pad; + if (out_dim_index < in_dim) { + return out_dim_index; + } + // right pad + out_dim_index -= in_dim; + const int index_sum = in_dim - 1 - offset; + return MSMAX(index_sum - out_dim_index, 0); +} + +int GetInputFlattenIndex(int out_flatten_index, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset) { + int in_flatten_index = 0; + for (int i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + int left_pad = paddings[i * 2]; + NNACL_CHECK_ZERO_RETURN_ERR(out_strides[i]); + int out_dim_index = out_flatten_index / out_strides[i]; + out_flatten_index %= out_strides[i]; + int in_dim_index = TransOut2InputDimIndex(out_dim_index, left_pad, input_shape[i], mirror_offset); + in_flatten_index += in_dim_index * in_strides[i]; + } + return in_flatten_index; +} + +void MirrorPad(const float *input_data, float *output_data, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset, int begin, int end) { + for (int i = begin; i < end; ++i) { + output_data[i] = input_data[GetInputFlattenIndex(i, input_shape, in_strides, out_strides, paddings, mirror_offset)]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h new file mode 100644 index 00000000..762000f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pad_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_FP32_PAD_FP32_H_
+#define NNACL_FP32_PAD_FP32_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include <stdint.h>
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/pad_parameter.h"
+
+int GetInputFlattenIndex(int out_flatten_index, const int32_t *input_shape, const int *in_strides,
+                         const int *out_strides, const int *paddings, int mirror_offset);
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Pad(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape,
+         const int32_t *paddings, int tid, int thread_num);
+void MirrorPad(const float *input_data, float *output_data, const int32_t *input_shape, const int *in_strides,
+               const int *out_strides, const int *paddings, int mirror_offset, int begin, int end);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_PAD_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c
new file mode 100644
index 00000000..a3653b94
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.c
@@ -0,0 +1,786 @@
+/**
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include <float.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/pooling_fp32_simd.h"
+
+int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
+                    const PoolingComputeParam *pooling_args, int task_id, int thread_num) {
+  int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_;
+  int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_;
+  int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_;
+  int channel = pooling_args->input_channel_;
+  int out_plane = output_w * output_h;
+  int out_tile_count = UP_DIV(out_plane, TILE_NUM);
+  NNACL_CHECK_ZERO_RETURN_ERR(output_w);
+
+  for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
+    int cal_start_index = thread_id * TILE_NUM;
+    int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ?
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float *src_plane_ptr = src_b_ptr; + float *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + int ci = 0; + + NNACL_CHECK_TRUE_RET(real_win_h_end > real_win_h_start, NNACL_ERR); + NNACL_CHECK_TRUE_RET(real_win_w_end > real_win_w_start, NNACL_ERR); + SIMD_RUN_NO_SCALAR(AvgPoolingBatch, ci, src_plane_ptr, channel, dst_plane_ptr, real_win_h_start, real_win_h_end, + real_win_w_start, real_win_w_end, in_h_index, in_w, in_w_index, pooling_args->minf, + pooling_args->maxf); + + for (; ci < channel; ci++) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + float tmp_avg = 0; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += src_win_ptr[0]; + ++real_count; + } // win_w loop + } // win_h loop + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fmaxf(tmp_avg, pooling_args->minf); + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + return NNACL_OK; +} + +int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = AvgPoolingBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + const int c_xtile = once_calc_num * c_tile; + + int cur_c = (channel / c_xtile) * c_xtile; + int 
last_c_size = channel - cur_c; + + int less_out_plane = out_plane * last_c_size; + int calc_tile = UP_DIV(less_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane; + + int c_start = (index_begin / out_plane) + cur_c; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) + cur_c; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + + int in_w_cx_line = in_w * last_c_size; + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; c < channel; c += c_xtile) { + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + float tmp_avg = 0.0; + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start); + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c); + tmp_avg += cur_input_index[0]; + } + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; + } + w = 0; + } + h = 0; + } + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; + MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(pooling_args->minf); + MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(pooling_args->maxf); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; + MS_FLOAT32X4 min_value = MS_MOVQ_F32(pooling_args->minf); + MS_FLOAT32X4 max_value = MS_MOVQ_F32(pooling_args->maxf); +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + int in_w_cx_line = in_w * c_tile; + const int c_xtile = once_calc_num * c_tile; + int c_tile_num = channel / c_xtile; + int all_out_plane = out_plane * c_tile_num; + int calc_tile = UP_DIV(all_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < all_out_plane ? 
(index_begin + calc_tile) : all_out_plane; + + int c_start = (index_begin / out_plane) * c_xtile; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) * c_xtile; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + for (; c < channel; c += c_xtile) { + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + +#ifdef ENABLE_AVX + MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0); + MS_FLOAT32X8 tmp_avg2 = MS_MOV256_F32(0); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0); +#else + float tmp_avg = 0; +#endif + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start); + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile; +#ifdef ENABLE_AVX + tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(cur_input_index)); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(cur_input_index)); +#else + tmp_avg += cur_input_index[0]; +#endif + } + +#ifdef ENABLE_AVX + const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile; + + tmp_avg2 = MS_ADD256_F32(tmp_avg2, MS_LD256_F32(cur_input_index)); + } +#endif + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; +#ifdef ENABLE_AVX + float *dst_c2_ptr = dst_c_ptr + c_tile; + + tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count)); + tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8); + tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8); + MS_ST256_F32(dst_c_ptr, tmp_avg); + + tmp_avg2 = MS_DIV256_F32(tmp_avg2, MS_MOV256_F32(real_count)); + tmp_avg2 = MS_MAX256_F32(tmp_avg2, min_value_8); + tmp_avg2 = MS_MIN256_F32(tmp_avg2, max_value_8); + MS_ST256_F32(dst_c2_ptr, tmp_avg2); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count)); + tmp_avg = MS_MAXQ_F32(tmp_avg, min_value); + tmp_avg = MS_MINQ_F32(tmp_avg, max_value); + MS_STQ_F32(dst_c_ptr, tmp_avg); +#else + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fmaxf(tmp_avg, pooling_args->minf); + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; +#endif + } + w = 0; + } + h = 0; + } + + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, 
int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = AvgPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + + ret = AvgPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float *src_plane_ptr = src_b_ptr; + float *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + int ci = 0; + + SIMD_RUN_NO_SCALAR(MaxPoolingBatch, ci, src_plane_ptr, channel, dst_plane_ptr, real_win_h_start, real_win_h_end, + real_win_w_start, real_win_w_end, in_h_index, in_w, in_w_index, pooling_args->minf, + pooling_args->maxf); + + for (; ci < channel; ci++) { + float *dst_c_ptr = dst_plane_ptr + ci; + const float *src_c_ptr = src_plane_ptr + ci; + float tmp_max = -FLT_MAX; + for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; + tmp_max = fmaxf(tmp_max, src_win_ptr[0]); + } // win_w loop + } // win_h loop + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + return NNACL_OK; +} + +int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = 
pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = MaxPoolingBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + const int c_xtile = once_calc_num * c_tile; + + int cur_c = (channel / c_xtile) * c_xtile; + int last_c_size = channel - cur_c; + + int less_out_plane = out_plane * last_c_size; + int calc_tile = UP_DIV(less_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane; + + int c_start = (index_begin / out_plane) + cur_c; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) + cur_c; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + + int in_w_cx_line = in_w * last_c_size; + const float *src_c_ptr = src_b_ptr + cur_c * in_plane; + for (; c < channel; c++) { + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + float tmp_max = -FLT_MAX; + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c); + tmp_max = fmaxf(tmp_max, cur_input_index[0]); + } + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; + } + w = 0; + } + h = 0; + } + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, 
int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; + MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(pooling_args->minf); + MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(pooling_args->maxf); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; + MS_FLOAT32X4 min_value = MS_MOVQ_F32(pooling_args->minf); + MS_FLOAT32X4 max_value = MS_MOVQ_F32(pooling_args->maxf); +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + int in_w_cx_line = in_w * c_tile; + const int c_xtile = once_calc_num * c_tile; + int c_tile_num = channel / c_xtile; + int all_out_plane = out_plane * c_tile_num; + int calc_tile = UP_DIV(all_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < all_out_plane ? (index_begin + calc_tile) : all_out_plane; + + int c_start = (index_begin / out_plane) * c_xtile; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) * c_xtile; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + for (; c < channel; c += c_xtile) { + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + +#ifdef ENABLE_AVX + MS_FLOAT32X8 tmp_max = MS_MOV256_F32(-FLT_MAX); + MS_FLOAT32X8 tmp_max2 = MS_MOV256_F32(-FLT_MAX); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 tmp_max = MS_MOVQ_F32(-FLT_MAX); +#else + float tmp_max = -FLT_MAX; +#endif + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile; +#ifdef ENABLE_AVX + tmp_max = MS_MAX256_F32(tmp_max, MS_LD256_F32(cur_input_index)); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_max = MS_MAXQ_F32(tmp_max, MS_LDQ_F32(cur_input_index)); +#else + tmp_max = fmaxf(tmp_max, cur_input_index[0]); +#endif + } + +#ifdef ENABLE_AVX + const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile; + + tmp_max2 = MS_MAX256_F32(tmp_max2, MS_LD256_F32(cur_input_index)); + } +#endif + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * 
channel + c; +#ifdef ENABLE_AVX + float *dst_c2_ptr = dst_c_ptr + c_tile; + + tmp_max = MS_MAX256_F32(tmp_max, min_value_8); + tmp_max = MS_MIN256_F32(tmp_max, max_value_8); + MS_ST256_F32(dst_c_ptr, tmp_max); + + tmp_max2 = MS_MAX256_F32(tmp_max2, min_value_8); + tmp_max2 = MS_MIN256_F32(tmp_max2, max_value_8); + MS_ST256_F32(dst_c2_ptr, tmp_max2); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_max = MS_MAXQ_F32(tmp_max, min_value); + tmp_max = MS_MINQ_F32(tmp_max, max_value); + MS_STQ_F32(dst_c_ptr, tmp_max); +#else + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; +#endif + } + w = 0; + } + h = 0; + } + + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = MaxPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + + ret = MaxPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end) { + // Access structure members in declaration order + int in_size_w = pooling_args->pooling_compute_param_.input_w_; + int in_size_h = pooling_args->pooling_compute_param_.input_h_; + int batch = pooling_args->pooling_compute_param_.input_batch_; + int channel = pooling_args->pooling_compute_param_.input_channel_; + int out_size_w = pooling_args->pooling_compute_param_.output_w_; + int out_size_h = pooling_args->pooling_compute_param_.output_h_; + int in_size_d = pooling_args->input_d_; + int out_size_d = pooling_args->output_d_; + + int kernel_w = pooling_param->pooling_parameter_.window_w_; + int kernel_h = pooling_param->pooling_parameter_.window_h_; + int stride_w = pooling_param->pooling_parameter_.stride_w_; + int stride_h = pooling_param->pooling_parameter_.stride_h_; + int pad_l_h = pooling_param->pooling_parameter_.pad_u_; + int pad_l_w = pooling_param->pooling_parameter_.pad_l_; + int kernel_d = pooling_param->window_d_; + int stride_d = pooling_param->stride_d_; + int pad_l_d = pooling_param->pad_f_; + + int n_stride = in_size_d * in_size_h * in_size_w * channel; + int d_stride = in_size_h * in_size_w * channel; + int h_stride = in_size_w * channel; + + int n = 0, d = 0, h = 0, w = 0; + const int parallel_dims = 4; // parallel on N/D/H/W four dims + offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, + batch); + + for (int i = start; i < end; i++) { + int d_start = d * stride_d - pad_l_d; + int d_end = MSMIN(d_start + kernel_d, in_size_d); + d_start = MSMAX(d_start, 0); + int h_start = h * stride_h - pad_l_h; + int h_end = MSMIN(h_start + kernel_h, in_size_h); + h_start = 
MSMAX(h_start, 0); + int w_start = w * stride_w - pad_l_w; + int w_end = MSMIN(w_start + kernel_w, in_size_w); + w_start = MSMAX(w_start, 0); + + const float *src_batch_ptr = input_ptr + n * n_stride; + float *out = output_ptr + i * channel; + int c_idx = 0; + SIMD_RUN_NO_SCALAR(MaxPooling3D, c_idx, src_batch_ptr, channel, out, d_start, d_end, h_start, h_end, w_start, w_end, + d_stride, h_stride); + for (; c_idx < channel; ++c_idx) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + float tmp_max = -FLT_MAX; + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_max = MSMAX(input[0], tmp_max); + } + } + } + dst_c_ptr[0] = tmp_max; + } + offset_to_index_step(parallel_dims * 2, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch); + } +} + +void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end) { + // Access structure members in declaration order + int in_size_w = pooling_args->pooling_compute_param_.input_w_; + int in_size_h = pooling_args->pooling_compute_param_.input_h_; + int batch = pooling_args->pooling_compute_param_.input_batch_; + int channel = pooling_args->pooling_compute_param_.input_channel_; + int out_size_w = pooling_args->pooling_compute_param_.output_w_; + int out_size_h = pooling_args->pooling_compute_param_.output_h_; + int in_size_d = pooling_args->input_d_; + int out_size_d = pooling_args->output_d_; + + int kernel_w = pooling_param->pooling_parameter_.window_w_; + int kernel_h = pooling_param->pooling_parameter_.window_h_; + int stride_w = pooling_param->pooling_parameter_.stride_w_; + int stride_h = pooling_param->pooling_parameter_.stride_h_; + int pad_l_h = pooling_param->pooling_parameter_.pad_u_; + int pad_r_h = pooling_param->pooling_parameter_.pad_d_; + int pad_l_w = pooling_param->pooling_parameter_.pad_l_; + int pad_r_w = pooling_param->pooling_parameter_.pad_r_; + int kernel_d = pooling_param->window_d_; + int stride_d = pooling_param->stride_d_; + int pad_l_d = pooling_param->pad_f_; + int pad_r_d = pooling_param->pad_b_; + bool count_include_pad = pooling_param->count_include_pad_; + int divisor = pooling_param->divisor_override_; + + int n_stride = in_size_d * in_size_h * in_size_w * channel; + int d_stride = in_size_h * in_size_w * channel; + int h_stride = in_size_w * channel; + + const int d_max = in_size_d + pad_r_d; + const int h_max = in_size_h + pad_r_h; + const int w_max = in_size_w + pad_r_w; + + int n = 0, d = 0, h = 0, w = 0; + const int parallel_dims = 4; // parallel on N/D/H/W four dims + offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, + batch); + + for (int i = start; i < end; i++) { + int d_start = d * stride_d - pad_l_d; + int d_end = MSMIN(d_start + kernel_d, d_max); + int d_start2 = MSMAX(d_start, 0); + int d_end2 = MSMIN(d_end, in_size_d); + int h_start = h * stride_h - pad_l_h; + int h_end = MSMIN(h_start + kernel_h, h_max); + int h_start2 = MSMAX(h_start, 0); + int h_end2 = MSMIN(h_end, in_size_h); + int w_start = w * stride_w - pad_l_w; + int w_end = MSMIN(w_start + kernel_w, w_max); + int w_start2 = MSMAX(w_start, 0); + int w_end2 = MSMIN(w_end, in_size_w); + + const float *src_batch_ptr = input_ptr + n * n_stride; + float *out = output_ptr + i * 
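/* i is the flattened (n, d, h, w) output index, so each iteration fills one NDHWC output pixel of channel floats */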
channel; + + if (pooling_param->divisor_override_ == 0) { + if (count_include_pad) { + divisor = (d_end - d_start) * (h_end - h_start) * (w_end - w_start); + } else { + divisor = (d_end2 - d_start2) * (h_end2 - h_start2) * (w_end2 - w_start2); + } + } + + int c_idx = 0; + SIMD_RUN_NO_SCALAR(AvgPooling3D, c_idx, src_batch_ptr, channel, out, d_start2, d_end2, h_start2, h_end2, w_start2, + w_end2, d_stride, h_stride, divisor); + for (; c_idx < channel; ++c_idx) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + float tmp_avg = 0; + for (int dd = d_start2; dd < d_end2; ++dd) { + for (int hh = h_start2; hh < h_end2; ++hh) { + for (int ww = w_start2; ww < w_end2; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_avg = tmp_avg + input[0]; + } + } + } + dst_c_ptr[0] = tmp_avg / (float)divisor; + } + offset_to_index_step(parallel_dims * 2, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.h new file mode 100644 index 00000000..10bb1210 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP32_POOLING_H_ +#define NNACL_FP32_POOLING_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); + +int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end); +void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_POOLING_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32_simd.h.in new file mode 100644 index 00000000..bdfb829f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/pooling_fp32_simd.h.in @@ -0,0 +1,116 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_FP32_POOLING_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_POOLING_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int AvgPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_plane_ptr, int channel, + float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end, + int in_h_index, int in_w, int in_w_index, float minf, float maxf) { + SIMD_F32 min_val = SIMD_MOV_F32(minf); + SIMD_F32 max_val = SIMD_MOV_F32(maxf); + for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + SIMD_F32 tmp_avg = SIMD_SET0_F32; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg = SIMD_ADD_F32(tmp_avg, SIMD_LD_F32(src_win_ptr)); + ++real_count; + } + } + tmp_avg = SIMD_DIV_F32(tmp_avg, SIMD_MOV_F32(real_count)); + tmp_avg = SIMD_MAX_F32(tmp_avg, min_val); + tmp_avg = SIMD_MIN_F32(tmp_avg, max_val); + SIMD_ST_F32(dst_c_ptr, tmp_avg); + } + return ci; +} + +static inline int MaxPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_plane_ptr, int channel, + float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end, + int in_h_index, int in_w, int in_w_index, float minf, float maxf) { + SIMD_F32 min_val = SIMD_MOV_F32(minf); + SIMD_F32 max_val = SIMD_MOV_F32(maxf); + for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + SIMD_F32 tmp_max = min_val; + for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; + tmp_max = SIMD_MAX_F32(tmp_max, SIMD_LD_F32(src_win_ptr)); + } + } + tmp_max = SIMD_MIN_F32(tmp_max, max_val); + SIMD_ST_F32(dst_c_ptr, tmp_max); + } + return ci; +} + +static inline int MaxPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out, + int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride) { + for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + SIMD_F32 tmp_max = SIMD_MOV_F32(-FLT_MAX); + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_max = SIMD_MAX_F32(SIMD_LD_F32(input), tmp_max); + } + } + } + SIMD_ST_F32(dst_c_ptr, tmp_max); + } + return c_idx; +} + +static inline int AvgPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out, + int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride, int divisor) { + for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) { + const float *src_c_ptr = 
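/* BLOCK_NUM consecutive channels are reduced per step; NDHWC keeps them contiguous in memory */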
src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + SIMD_F32 tmp_avg = SIMD_SET0_F32; + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_avg = SIMD_ADD_F32(SIMD_LD_F32(input), tmp_avg); + } + } + } + tmp_avg = SIMD_DIV_F32(tmp_avg, SIMD_MOV_F32(divisor)); + SIMD_ST_F32(dst_c_ptr, tmp_avg); + } + return c_idx; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.c new file mode 100644 index 00000000..4113e51c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.c @@ -0,0 +1,70 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/power_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/power_fp32_simd.h" + +float OptimizedPowerScalar(float x, const float *exponent) { + int exp = abs((int)(*exponent)); + float result = 1; + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + return *exponent >= 0 ? result : 1 / result; +} + +void PowerBroadCast(const float *input, const float *exponent, float *output, int len, float scale, float shift) { + PowerScalarFun PowerScalarFun_ = NULL; + + int i = 0; + if (CheckInteger(*exponent)) { + PowerScalarFun_ = OptimizedPowerScalar; + SIMD_RUN_NO_SCALAR(PowerBroadCastIntExponent, i, input, (int)(*exponent), output, len, scale, shift); + } else { + PowerScalarFun_ = StdPowerScalar; + SIMD_RUN_NO_SCALAR(PowerBroadCastFloatExponent, i, input, *exponent, output, len, scale, shift); + } + + for (; i < len; ++i) { + output[i] = PowerScalarFun_(scale * input[i] + shift, exponent); + } +} + +void PowerSingle(const float *input, const float *exponent, float *output, int len, float scale, float shift) { + int i = 0; + + SIMD_RUN_NO_SCALAR(PowerSingleExponent, i, input, exponent, output, len, scale, shift); + PowerScalarFun PowerScalarFun_ = NULL; + for (; i < len; ++i) { + PowerScalarFun_ = CheckInteger(exponent[i]) ? OptimizedPowerScalar : StdPowerScalar; + output[i] = PowerScalarFun_(scale * input[i] + shift, exponent + i); + } +} + +int Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast) { + if (input == NULL || exponent == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + PowerFun PowerFun_ = NULL; + PowerFun_ = broadcast ? 
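/* broadcast: one scalar exponent shared by the whole tensor; otherwise one exponent per element */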
PowerBroadCast : PowerSingle; + PowerFun_(input, exponent, output, len, scale, shift); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.h new file mode 100644 index 00000000..b2a2fb79 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_POWER_FP32_H_ +#define MINDSPORE_NNACL_FP32_POWER_FP32_H_ + +#include <math.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/pow_parameter.h" + +typedef void (*PowerFun)(const float *, const float *, float *, int, float, float); +typedef float (*PowerScalarFun)(float x, const float *exponent); + +#ifdef __cplusplus +extern "C" { +#endif +static inline bool CheckInteger(float f) { return fabsf(f - (int)(f)) < 0.000001; } + +static inline float StdPowerScalar(float x, const float *exponent) { return powf(x, *exponent); } + +int Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast); +void PowerSingle(const float *input, const float *exponent, float *output, int len, float scale, float shift); +void PowerBroadCast(const float *input, const float *exponent, float *output, int len, float scale, float shift); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_POWER_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32_simd.h.in new file mode 100644 index 00000000..2b9f398c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/power_fp32_simd.h.in @@ -0,0 +1,94 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_FP32_POWER_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_POWER_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int PowerBroadCastIntExponent@SIMD_INSTRUCTION@(int index, const float *input, int exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + SIMD_F32 result = SIMD_MOV_F32(1.0f); + int exp = abs(exponent); + while (exp) { + if (exp % 2) { + result = SIMD_MUL_F32(result, tmp); + } + tmp = SIMD_MUL_SQUARE_F32(tmp); + exp = exp / 2; + } + SIMD_ST_F32(output + index, exponent >= 0 ? result : SIMD_DIV_F32(SIMD_MOV_F32(1), result)); + } + return index; +} + +static inline int PowerBroadCastFloatExponent@SIMD_INSTRUCTION@(int index, const float *input, float exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + SIMD_F32 result; + for (int i = 0; i < BLOCK_NUM; ++i) { + SIMD_F32_GETI(result, i) = powf(SIMD_F32_GETI(tmp, i), exponent); + } + SIMD_ST_F32(output + index, result); + } + return index; +} + +static inline int PowerSingleExponent@SIMD_INSTRUCTION@(int index, const float *input, const float *exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp_vec = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + for (int j = 0; j < BLOCK_NUM; ++j) { + float cur_exponent = exponent[index + j]; + float cur_val = SIMD_F32_GETI(tmp_vec, j); + if (fabsf(cur_exponent - (int)(cur_exponent)) < 0.000001) { + int exp = abs((int)(cur_exponent)); + float result = 1; + while (exp) { + if (exp % 2) { + result *= cur_val; + } + cur_val *= cur_val; + exp = exp / 2; + } + output[index + j] = cur_exponent >= 0 ? result : 1 / result; + } else { + output[index + j] = powf(cur_val, cur_exponent); + } + } + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.c new file mode 100644 index 00000000..d57e63b9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.c @@ -0,0 +1,164 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/prelu_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +#ifdef ENABLE_ARM64 +static inline void PRelu4x16(const float *in, float *out, const float *cur_slope, size_t step) { + asm volatile( + "mov x10, %[in]\n" + "mov x11, %[out]\n" + "mov x12, %[cur_slope]\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12]\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" + "fmul v16.4s, v0.4s, v4.4s\n" + "fmul v17.4s, v1.4s, v5.4s\n" + "fmul v18.4s, v2.4s, v6.4s\n" + "fmul v19.4s, v3.4s, v7.4s\n" + "fcmgt v20.4s, v0.4s, #0\n" + "fcmgt v21.4s, v1.4s, #0\n" + "fcmgt v22.4s, v2.4s, #0\n" + "fcmgt v23.4s, v3.4s, #0\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], %[step]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.4s, v24.4s, v4.4s\n" + "fmul v9.4s, v25.4s, v5.4s\n" + "fmul v10.4s, v26.4s, v6.4s\n" + "fmul v11.4s, v27.4s, v7.4s\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" + "fcmgt v12.4s, v24.4s, #0\n" + "fcmgt v13.4s, v25.4s, #0\n" + "fcmgt v14.4s, v26.4s, #0\n" + "fcmgt v15.4s, v27.4s, #0\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "fmul v16.4s, v0.4s, v4.4s\n" + "fmul v17.4s, v1.4s, v5.4s\n" + "fmul v18.4s, v2.4s, v6.4s\n" + "fmul v19.4s, v3.4s, v7.4s\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], %[step]\n" + "fcmgt v20.4s, v0.4s, #0\n" + "fcmgt v21.4s, v1.4s, #0\n" + "fcmgt v22.4s, v2.4s, #0\n" + "fcmgt v23.4s, v3.4s, #0\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.4s, v24.4s, v4.4s\n" + "fmul v9.4s, v25.4s, v5.4s\n" + "fmul v10.4s, v26.4s, v6.4s\n" + "fmul v11.4s, v27.4s, v7.4s\n" + "fcmgt v12.4s, v24.4s, #0\n" + "fcmgt v13.4s, v25.4s, #0\n" + "fcmgt v14.4s, v26.4s, #0\n" + "fcmgt v15.4s, v27.4s, #0\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11]\n" + : + : [in] "r"(in), [out] "r"(out), [cur_slope] "r"(cur_slope), [step] "r"(step) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27"); +} +#endif + +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel) { + int i = start; +#ifdef ENABLE_ARM64 + for (; i < end - 3; i += 4) { + const float *cur_in = input + i * channel; + float *cur_out = output + i * channel; + int j = 0; + for (; j < channel - 15; j += 16) { + const float *in = cur_in + j; + float *out = cur_out + j; + const float *cur_slope = slope + j; + size_t step = channel * sizeof(float); + PRelu4x16(in, out, cur_slope, step); + } + for (; j < channel; j++) { + cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]); + cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j]; + cur_out[j + 2 * channel] = + (cur_in[j + 2 * channel] > 0) ? 
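/* scalar tail covering the last (channel % 16) channels of the four rows unrolled by PRelu4x16 */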
cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]); + cur_out[j + 3 * channel] = + (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]); + } + } +#endif + for (; i < end; i++) { + const float *cur_in = input + i * channel; + float *cur_out = output + i * channel; + int j = 0; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; j < channel - 3; j += 4) { + MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j); + MS_FLOAT32X4 s = MS_LDQ_F32(slope + j); + MS_FLOAT32X4 mul = MS_MULQ_F32(in, s); + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 res = MS_BLENDQ_F32(in, mul, MS_CMPLEQ_F32(in, zero)); + MS_STQ_F32(cur_out + j, res); + } +#endif + for (; j < channel; j++) { + if (cur_in[j] > 0) { + cur_out[j] = cur_in[j]; + } else { + cur_out[j] = cur_in[j] * slope[j]; + } + } + } +} + +void PReluShareChannel(const float *input, float *output, float slope, int start, int end) { + int i = start; +#if defined(ENABLE_AVX) +#define mask_offset 30 + for (; i <= end - C8NUM; i += C8NUM) { + MS_FLOAT32X8 src_tmp = MS_LD256_F32(input + i); + MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, slope); + MS_FLOAT32X8 mask = MS_CMP256_F32(src_tmp, MS_MOV256_F32(0.0f), mask_offset); + MS_ST256_F32(output + i, MS_BLEND256_F32(mul_tmp, src_tmp, mask)); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; i <= end - C4NUM; i += C4NUM) { + MS_FLOAT32X4 src_tmp = MS_LDQ_F32(input + i); + MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, slope); +#ifdef ENABLE_ARM + MS_UINT32X4 mask = MS_CMPLEQ_F32(src_tmp, MS_MOVQ_F32(0.0f)); +#else + MS_FLOAT32X4 mask = MS_CMPLEQ_F32(src_tmp, MS_MOVQ_F32(0.0f)); +#endif + MS_STQ_F32(output + i, MS_BLENDQ_F32(src_tmp, mul_tmp, mask)); + } +#endif + for (; i < end; i++) { + output[i] = input[i] > 0 ? input[i] : input[i] * slope; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.h new file mode 100644 index 00000000..3c3c7f2f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prelu_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_PRELU_H_ +#define MINDSPORE_NNACL_FP32_PRELU_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel); + +void PReluShareChannel(const float *input, float *output, float slope, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_PRELU_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prior_box_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prior_box_fp32.h new file mode 100644 index 00000000..2ef55f69 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/prior_box_fp32.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ +#define MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/prior_box_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +static int PriorBox(const float *input_data, float *output_data, const size_t size, const int tid, + const int thread_num) { + NNACL_CHECK_NULL_RETURN_ERR(input_data); + NNACL_CHECK_NULL_RETURN_ERR(output_data); + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + size_t unit_size = size / thread_num; + size_t copy_size = (tid == thread_num - 1) ? size - unit_size * tid : unit_size; + (void)memcpy(output_data + tid * unit_size, input_data + tid * unit_size, copy_size * sizeof(float)); + return NNACL_OK; +} +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.c new file mode 100644 index 00000000..b73d4c81 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.c @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/ragged_range_fp32.h" +#include <math.h> +#include "nnacl_c/op_base.h" + +void RaggedRangeFp32(const float *starts, const float *limits, const float *deltas, int32_t *splits, float *value, + RaggedRangeStruct *ragged_range) { + splits[0] = 0; + for (int i = 0; i < ragged_range->rows_; i++) { + float start = ragged_range->starts_is_scalar_ ?
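/* scalar inputs are broadcast to every row */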
starts[0] : starts[i]; + float limit = ragged_range->limits_is_scalar_ ? limits[0] : limits[i]; + float delta = ragged_range->deltas_is_scalar_ ? deltas[0] : deltas[i]; + int len = NNACL_MAX((int)ceil((float)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} + +void RaggedRangeInt(const int32_t *starts, const int32_t *limits, const int32_t *deltas, int32_t *splits, + int32_t *value, RaggedRangeStruct *ragged_range) { + splits[0] = 0; + for (int i = 0; i < ragged_range->rows_; i++) { + int start = ragged_range->starts_is_scalar_ ? starts[0] : starts[i]; + int limit = ragged_range->limits_is_scalar_ ? limits[0] : limits[i]; + int delta = ragged_range->deltas_is_scalar_ ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN(delta); + int len = NNACL_MAX((int)ceil((float)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.h new file mode 100644 index 00000000..e44846dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/ragged_range_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_RAGGED_RANGE_FP32_H_ +#define NNACL_FP32_RAGGED_RANGE_FP32_H_ + +#include "nnacl_c/kernel/ragged_range.h" + +void RaggedRangeFp32(const float *starts, const float *limits, const float *deltas, int32_t *splits, float *value, + RaggedRangeStruct *ragged_range); +void RaggedRangeInt(const int32_t *starts, const int32_t *limits, const int32_t *deltas, int32_t *splits, + int32_t *value, RaggedRangeStruct *ragged_range); + +#endif // NNACL_FP32_RAGGED_RANGE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/range_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/range_fp32.h new file mode 100644 index 00000000..2058ff1f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/range_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_FP32_RANGE_FP32_H_ +#define NNACL_FP32_RANGE_FP32_H_ + +#include "nnacl_c/op_base.h" + +void Range(float *output_ptr, float start, float delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +void RangeInt(int32_t *output_ptr, int start, int delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +#endif // NNACL_FP32_RANGE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rank_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rank_fp32.h new file mode 100644 index 00000000..fa7857e1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rank_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANK_H_ +#define MINDSPORE_NNACL_RANK_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +inline void Rank(float *output, int rank) { + output[0] = (float)(rank); + return; +} +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_RANK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.c new file mode 100644 index 00000000..d8931b34 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.c @@ -0,0 +1,359 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/reduce_fp32.h" +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/reduce_fp32_simd.h" +#ifdef ENABLE_NNACL_INFER_SHAPE +#include "nnacl_c/reduce_parameter.h" +#endif + +// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1) +#define ReduceCoreCalc(op_name, op_type, outer_src, outer_dst, k) \ + for (; k < inner_size; k++) { \ + const op_type *inner_src = outer_src + k; \ + op_name##PreDeal; \ + for (int i = 0; i < axis_size; i++) { \ + op_name##MidCalc; \ + } \ + op_name##PostDeal; \ + } + +#define RegReduceOp(op_name, op_type) \ + int op_name(int outer_size, int inner_size, int axis_size, const op_type *src_data, op_type *dst_data, int tid, \ + int thread_num) { \ + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); \ + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); \ + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); \ + for (int j = tid; j < outer_size; j += thread_num) { \ + const op_type *outer_src = src_data + j * axis_size * inner_size; \ + op_type *outer_dst = dst_data + j * inner_size; \ + int k = 0; \ + SIMD_RUN_NO_SCALAR(op_name, k, outer_src, outer_dst, inner_size, axis_size); \ + \ + ReduceCoreCalc(op_name, op_type, outer_src, outer_dst, k); \ + } \ + return NNACL_OK; \ + } + +// ReduceSum +#define ReduceSumPreDeal float tmp = 0; +#define ReduceSumMidCalc tmp += inner_src[i * inner_size]; +#define ReduceSumPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceSum, float); + +int ReduceSumByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num) { + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); + + for (int j = tid; j < outer_size; j += thread_num) { + const float *src_tmp = src_data + j * axis_size; + + float tmp = src_tmp[0]; + int i = 1; + + SIMD_RUN_NO_SCALAR(ReduceSumByLastAxis, i, src_tmp, &tmp, axis_size); + for (; i < axis_size; i++) { + tmp += src_tmp[i]; + } + dst_data[j] = tmp; + } + return NNACL_OK; +} + +// ReduceMean +#define ReduceMeanPreDeal float tmp = 0; +#define ReduceMeanMidCalc tmp += inner_src[i * inner_size]; +#define ReduceMeanPostDeal outer_dst[k] = tmp / axis_size; +RegReduceOp(ReduceMean, float); + +// ReduceMin +#define ReduceMinPreDeal float tmp = FLT_MAX; +#define ReduceMinMidCalc tmp = fminf(tmp, inner_src[i * inner_size]); +#define ReduceMinPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceMin, float); + +// ReduceMax +#define ReduceMaxPreDeal float tmp = FLT_MIN; +#define ReduceMaxMidCalc tmp = fmaxf(tmp, inner_src[i * inner_size]); +#define ReduceMaxPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceMax, float); + +int ReduceMaxByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num) { + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); + + for (int j = tid; j < outer_size; j += thread_num) { + const float *src_tmp = src_data + j * axis_size; + + float tmp = src_tmp[0]; + int i = 1; + + SIMD_RUN_NO_SCALAR(ReduceMaxByLastAxis, i, src_tmp, &tmp, axis_size); + for (; i < axis_size; i++) { + tmp = fmaxf(tmp, src_tmp[i]); + } + dst_data[j] = tmp; + } + return NNACL_OK; +} + +// ReduceProd +#define ReduceProdPreDeal float tmp = 1.0f; 
+#define ReduceProdMidCalc tmp *= inner_src[i * inner_size]; +#define ReduceProdPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceProd, float); + +// ReduceSumSquare +#define ReduceSumSquarePreDeal float tmp = 0; +#define ReduceSumSquareMidCalc tmp += (inner_src[i * inner_size] * inner_src[i * inner_size]); +#define ReduceSumSquarePostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceSumSquare, float); + +// ReduceL2Norm +#define ReduceL2NormPreDeal float tmp = 0; +#define ReduceL2NormMidCalc tmp += (inner_src[i * inner_size] * inner_src[i * inner_size]); +#define ReduceL2NormPostDeal outer_dst[k] = sqrt(tmp); +RegReduceOp(ReduceL2Norm, float); + +// IntReduceSum +#define IntReduceSumPreDeal int tmp = 0; +#define IntReduceSumMidCalc tmp += inner_src[i * inner_size]; +#define IntReduceSumPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceSum, int32_t); + +// IntReduceMean +#define IntReduceMeanPreDeal int tmp = 0; +#define IntReduceMeanMidCalc tmp += inner_src[i * inner_size]; +#define IntReduceMeanPostDeal outer_dst[k] = tmp / axis_size; +RegReduceOp(IntReduceMean, int32_t); + +// IntReduceMin +#define IntReduceMinPreDeal int tmp = INT32_MAX; +#define IntReduceMinMidCalc tmp = MSMIN(tmp, inner_src[i * inner_size]); +#define IntReduceMinPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceMin, int32_t); + +// IntReduceMax +#define IntReduceMaxPreDeal int tmp = INT32_MIN; +#define IntReduceMaxMidCalc tmp = MSMAX(tmp, inner_src[i * inner_size]); +#define IntReduceMaxPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceMax, int32_t); + +int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid, + int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const bool *outer_src = src_data + j * axis_size * inner_size; + bool *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const bool *inner_src = outer_src + k; + bool *inner_dst = outer_dst + k; + bool tmp = true; + for (i = 0; i < axis_size; i++) { + tmp = tmp && inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int IntReduceProd(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int tmp = 1; + for (i = 0; i < axis_size; i++) { + if (isMulOverflow(tmp, inner_src[i * inner_size])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +#ifdef ENABLE_NNACL_INFER_SHAPE +int ReduceInferShape(int32_t **in_shape, size_t *dim_size, int32_t *out_shape, int32_t *in_format, int32_t *out_format, + int32_t *in_datatype, int32_t *out_datatype, OpParameter *param) { + *out_format = in_format[0]; + *out_datatype = in_datatype[0]; + ReduceParameter *reduce_parameter = (ReduceParameter *)param; + bool keep_dims = reduce_parameter->keep_dims_; + int num_axes = reduce_parameter->num_axes_; + int32_t *in_shape0 = in_shape[0]; + 
int rank = dim_size[0]; + NNACL_CHECK_TRUE_RET(rank > 0 && rank <= REDUCE_MAX_AXES_NUM, NNACL_PARAM_INVALID); + int axes[REDUCE_MAX_AXES_NUM]; + int actual_axes_num = num_axes; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_TRUE_RET(reduce_parameter->axes_[i] >= -rank && reduce_parameter->axes_[i] < rank, NNACL_PARAM_INVALID); + if (reduce_parameter->axes_[i] < 0) { + axes[i] = reduce_parameter->axes_[i] + rank; + } else { + axes[i] = reduce_parameter->axes_[i]; + } + } + if (reduce_parameter->reduce_to_end_) { + NNACL_CHECK_TRUE_RET(num_axes == 1, NNACL_PARAM_INVALID); + int begin_axis = axes[0]; + num_axes = rank - begin_axis; + for (int i = begin_axis + 1; i < rank; ++i) { + axes[actual_axes_num++] = i; + } + } + if (num_axes == 0) { + int j = 0; + for (int i = 0; i < rank; ++i) { + axes[i] = i; + if (keep_dims) { + out_shape[j++] = 1; + } + } + reduce_parameter->num_axes_ = rank; + for (int i = 0; i < rank; ++i) { + reduce_parameter->axes_[i] = axes[i]; + } + return NNACL_OK; + } + // reduce on selected axes + int j = 0; + for (int i = 0; i < rank; ++i) { + bool reduce_axis = false; + for (int idx = 0; idx < num_axes; ++idx) { + if (axes[idx] == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + out_shape[j++] = 1; + } + } else { + out_shape[j++] = in_shape0[i]; + } + } + reduce_parameter->num_axes_ = num_axes; + for (int i = 0; i < num_axes; ++i) { + reduce_parameter->axes_[i] = axes[i]; + } + return NNACL_OK; +} +#endif + +// [A, B] -> [B] +// col_size : start -> end for parallel +int ReduceSumDim2Axis0(size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + + size_t k = 0; + SIMD_RUN_NO_SCALAR(ReduceSumDim2Axis0, k, col_size, col_len, row_len, src_data, dst_data); + for (; k < col_size; k++) { + const float *inner_src = src_data + k; + float *inner_dst = dst_data + k; + float tmp = 0.0f; + for (size_t i = 0; i < row_len; i++) { + tmp += inner_src[i * col_len]; + } + *inner_dst = tmp; + } + return NNACL_OK; +} + +// [A, B] -> [A] +int ReduceSumDim2Axis1(size_t col_len, const float *src_data, float *dst_data) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + size_t k = 0; + float tmp = 0; +#ifdef ENABLE_AVX + size_t block_mod = col_len % C8NUM; + size_t block_c8 = col_len - block_mod; + float tmp_arr[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + MS_FLOAT32X8 tmp_arr_8 = MS_MOV256_F32(tmp_arr[0]); + for (; k < block_c8; k += C8NUM) { + MS_FLOAT32X8 src_in = MS_LD256_F32(src_data + k); + tmp_arr_8 = MS_ADD256_F32(tmp_arr_8, src_in); + } + MS_ST256_F32(tmp_arr, tmp_arr_8); + for (size_t i = 0; i < 8; ++i) { + tmp += tmp_arr[i]; + } +#endif + for (; k < col_len; k++) { + tmp += src_data[k]; + } + dst_data[0] = tmp; + return NNACL_OK; +} + +int ReduceMeanWithAxis(const float *src_data, float *mean, int64_t size) { + if (size == 0 || src_data == NULL) { + return NNACL_NULL_PTR; + } + float sum = 0.0; + int64_t i = 0; + SIMD_RUN_NO_SCALAR(ReduceSumByLastAxis, i, src_data, &sum, 0); + for (; i < size; ++i) { + sum += src_data[i]; + } + *mean = sum / size; + return NNACL_OK; +} + +int ReduceDeviation(const float *src_data, int64_t size, float mean, float *deviation) { + if (size == 0 || src_data == NULL) { + return NNACL_NULL_PTR; + } + int64_t i = 0; + SIMD_RUN_NO_SCALAR(FloatReduceDeviation, i, src_data, mean, size, deviation); + for (; i < size; ++i) { + float tmp = src_data[i] - mean; + tmp = tmp * tmp; + *deviation += tmp; 
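+      /* accumulates the squared distance from the precomputed mean (second pass of a two-pass variance) */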
+ } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.h new file mode 100644 index 00000000..c33c4df4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_REDUCE_H_ +#define MINDSPORE_NNACL_FP32_REDUCE_H_ +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/reduce_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMean(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceSumByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceSum(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceMaxByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMax(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMin(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceProd(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceProd(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceL2Norm(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid, + int thread_num); +int ReduceSumDim2Axis0(size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data); +int ReduceSumDim2Axis1(size_t col_len, const float *src_data, float *dst_data); +int ReduceMeanWithAxis(const float *src_data, float *mean, int64_t size); +int ReduceDeviation(const float *src_data, int64_t size, float mean, float *deviation); + +#ifdef ENABLE_NNACL_INFER_SHAPE 
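+/* optional shape-inference entry; the matching definition lives in reduce_fp32.c under the same guard */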
+int ReduceInferShape(int32_t **in_shape, size_t *dim_size, int32_t *out_shape, int32_t *in_format, int32_t *out_format, + int32_t *in_datatype, int32_t *out_datatype, OpParameter *param); +#endif +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_REDUCE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32_simd.h.in new file mode 100644 index 00000000..eee95b42 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reduce_fp32_simd.h.in @@ -0,0 +1,220 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_FP32_REDUCE_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_REDUCE_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumByLastAxis@SIMD_INSTRUCTION@(int64_t index, const float *src, float* tmp_sum, int axis_size) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = axis_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(src + index)); + } + *tmp_sum += SIMD_GET_SUM_F32(tmp); + return index; +} + +static inline int64_t ReduceMean@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, SIMD_DIV_N_F32(tmp, axis_size)); + } + return index; +} + +static inline int64_t ReduceMin@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(FLT_MAX); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MIN_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceMax@SIMD_INSTRUCTION@(int64_t index, const 
float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(-FLT_MAX); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MAX_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceMaxByLastAxis@SIMD_INSTRUCTION@(int64_t index, const float *src, float* tmp_max, int axis_size) { + SIMD_F32 tmp = SIMD_MOV_F32(*tmp_max); + for (int block_max_size = axis_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_MAX_F32(tmp, SIMD_LD_F32(src + index)); + } + *tmp_max = SIMD_GET_MAX_F32(tmp); + return index; +} + +static inline int64_t ReduceProd@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(1.0f); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MUL_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumSquare@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_SQUARE_F32(SIMD_LD_F32(inner_src + i * inner_size))); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceL2Norm@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_SQUARE_F32(SIMD_LD_F32(inner_src + i * inner_size))); + } + SIMD_ST_F32(outer_dst + index, SIMD_SQRT_F32(tmp)); + } + return index; +} + +static inline int64_t IntReduceSum@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t IntReduceMean@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, SIMD_DIV_N_EPI32(tmp, axis_size)); + } + return index; +} + +static inline int64_t IntReduceMin@SIMD_INSTRUCTION@(int64_t index, const int32_t
*outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(INT32_MAX); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MIN_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t IntReduceMax@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(INT32_MIN); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MAX_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumDim2Axis0@SIMD_INSTRUCTION@(int64_t index, size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data) { + for (int block_max_size = col_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + const float *inner_src = src_data + index; + float *inner_dst = dst_data + index; + for (size_t i = 0; i < row_len; ++i) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * col_len)); + } + SIMD_ST_F32(inner_dst, tmp); + } + return index; +} + +static inline int64_t FloatReduceDeviation@SIMD_INSTRUCTION@(int64_t index, const float *src_data, float mean, size_t size, float *deviation) { + SIMD_F32 fs_deviation = SIMD_MOV_F32(0); + SIMD_F32 fs_mean = SIMD_MOV_F32(mean); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 fs_sub = SIMD_LD_F32(src_data + index); + + fs_sub = SIMD_SUB_F32(fs_sub, fs_mean); + SIMD_F32 fs_pow = SIMD_MUL_F32(fs_sub, fs_sub); + fs_deviation = SIMD_ADD_F32(fs_deviation, fs_pow); + } + *deviation += SIMD_GET_SUM_F32(fs_deviation); + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.c new file mode 100644 index 00000000..86476f0a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.c @@ -0,0 +1,598 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <math.h> +#include "nnacl_c/fp32/resize_fp32.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void CalculateCoordinate(float out, int in, int32_t *bottom, int32_t *top, float *bottom_weight) { + *bottom = (int)(floorf(out)); + *bottom = *bottom >= 0 ? *bottom : 0; // extrapolate may generate neg value + *top = *bottom + 1 < in ? 
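/* clamp top to the last input row */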
(*bottom + 1) : (in - 1); + float top_weight = (float)out - (float)(*bottom); + *bottom_weight = 1.0f - top_weight; +} + +static void BicubicBaseFunc(float a, const float x, float *weight) { + float abs_x = fabsf(x); + if (abs_x >= 0 && abs_x <= 1) { + *weight = ((a + 2) * abs_x - (a + 3)) * abs_x * abs_x + 1; + } else if (abs_x > 1 && abs_x <= 2) { + *weight = a * abs_x * abs_x * abs_x - 5 * a * abs_x * abs_x + 8 * a * abs_x - 4 * a; + } else { + *weight = 0; + } +} + +// a is a coefficient +// W(x) = { (a + 2) * |x| * |x| * |x| - (a + 3) * |x| * |x| + 1, for |x| <= 1 +// { a * |x| * |x| * |x| - 5 * a * |x| * |x| + 8 * a *|x| - 4 * a, for 1 < |x| < 2 +// { 0, otherwise +// the value of 'a' depends on whether half_pixel_centers is used (the scheme is the same as tf). +// In half-pixel mode, a equals -0.5; otherwise it is -0.75. +void CalculateWeightForBicubic(float out, int in, int32_t *index, float *weights, float a) { + int floor_index = (int)(floorf(out)); + index[0] = (floor_index - 1) < 0 ? 0 : (floor_index - 1); + index[1] = floor_index; + index[2] = (floor_index + 1) < in ? (floor_index + 1) : (in - 1); + index[3] = (floor_index + 2) < in ? (floor_index + 2) : (in - 1); + + // signed distances from the four sampling points to the interpolation position + float distance[4] = {-1, 0, 1, 2}; + float tmp_dis = out - (float)floor_index; + distance[0] -= tmp_dis; + distance[1] -= tmp_dis; + distance[2] -= tmp_dis; + distance[3] -= tmp_dis; + + for (int i = 0; i < 4; ++i) { + BicubicBaseFunc(a, distance[i], &weights[i]); + } +} + +int PrepareResizeBilinear(const int32_t *input_shape, const int32_t *output_shape, + CalculateOriginalCoordinate calculate, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float actual_y = calculate(h, in_h, new_height); + CalculateCoordinate(actual_y, in_h, y_bottoms + h, y_tops + h, y_bottom_weights + h); + } + for (int w = 0; w < new_width; w++) { + float actual_x = calculate(w, in_w, new_width); + CalculateCoordinate(actual_x, in_w, x_lefts + w, x_rights + w, x_left_weights + w); + } + return NNACL_OK; +} + +int PrepareResizeBicubic(const int32_t *input_shape, const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int32_t *y_tops, int32_t *x_lefts, float *y_weights, float *x_weights, float cubic_coeff) { + if (input_shape == NULL || output_shape == NULL || y_tops == NULL || x_lefts == NULL || y_weights == NULL || + x_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float actual_y = calculate(h, in_h, new_height); + CalculateWeightForBicubic(actual_y, in_h, y_tops + 4 * h, y_weights + 4 * h, cubic_coeff); + } + for (int w = 0; w < new_width; w++) { + float actual_x = calculate(w, in_w, new_width); + CalculateWeightForBicubic(actual_x, in_w, x_lefts + 4 * w, x_weights + 4 * w, cubic_coeff); + } + return NNACL_OK; +} + +int PrepareCropAndResizeBilinear(const int32_t *input_shape, const float *boxes, const int32_t *box_idx, + const int32_t *output_shape, int32_t *y_bottoms, int32_t 
*y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + float actual_x; + float actual_y; + + for (int b = 0; b < batch; b++) { + const float *box = boxes + b * 4; + float start_h = box[0]; + float end_h = box[2]; + float start_w = box[1]; + float end_w = box[3]; + + int32_t *y_bottom = y_bottoms + b * new_height; + int32_t *y_top = y_tops + b * new_height; + float *y_bottom_weight = y_bottom_weights + b * new_height; + int32_t *x_left = x_lefts + b * new_width; + int32_t *x_right = x_rights + b * new_width; + float *x_left_weight = x_left_weights + b * new_width; + for (int h = 0; h < new_height; h++) { + if (new_height > 1) { + actual_y = start_h * (in_h - 1) + h * (end_h - start_h) * (in_h - 1) / (new_height - 1); + } else { + actual_y = 0.5 * (end_h + start_h) * (in_h - 1); + } + CalculateCoordinate(actual_y, in_h, y_bottom + h, y_top + h, y_bottom_weight + h); + } + for (int w = 0; w < new_width; w++) { + if (new_width > 1) { + actual_x = start_w * (in_w - 1) + w * (end_w - start_w) * (in_w - 1) / (new_width - 1); + } else { + actual_x = 0.5 * (end_w + start_w) * (in_w - 1); + } + CalculateCoordinate(actual_x, in_w, x_left + w, x_right + w, x_left_weight + w); + } + } + return NNACL_OK; +} + +int InterpRow(const float *src_line, float *linear_output, int new_width, const float *x_left_weights, + const int32_t *x_lefts, const int32_t *x_rights, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 left_w_8 = MS_MOV256_F32(x_left_weights[w]); + MS_FLOAT32X8 right_w_8 = MS_MOV256_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 left = MS_LD256_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X8 right = MS_LD256_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(MS_MUL256_F32(left, left_w_8), MS_MUL256_F32(right, right_w_8)); + MS_ST256_F32(linear_output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 left_w = MS_MOVQ_F32(x_left_weights[w]); + MS_FLOAT32X4 right_w = MS_MOVQ_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 left = MS_LDQ_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X4 right = MS_LDQ_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(MS_MULQ_F32(left, left_w), MS_MULQ_F32(right, right_w)); + MS_STQ_F32(linear_output + w * in_c + c, interp_value); + } +#endif + int left_w_offset = x_lefts[w] * in_c; + int right_w_offset = x_rights[w] * in_c; + for (; c < in_c; c++) { + float left = src_line[left_w_offset + c]; + float right = src_line[right_w_offset + c]; + linear_output[w * in_c + c] = left * x_left_weights[w] + right * (1.0f - x_left_weights[w]); + } + } + return 0; +} + +int InterpCol(const float *bottom_line, const float *top_line, float *output, int new_width, float y_bottom_weight, + int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 bottom_w_8 = MS_MOV256_F32(y_bottom_weight); + MS_FLOAT32X8 top_w_8 = MS_MOV256_F32(1.0f - y_bottom_weight); 
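+ // out[c] = bottom[c] * w + top[c] * (1 - w), eight lanes at a time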
+ for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 bottom = MS_LD256_F32(bottom_line + w * in_c + c); + MS_FLOAT32X8 top = MS_LD256_F32(top_line + w * in_c + c); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(MS_MUL256_F32(bottom, bottom_w_8), MS_MUL256_F32(top, top_w_8)); + MS_ST256_F32(output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 bottom_w = MS_MOVQ_F32(y_bottom_weight); + MS_FLOAT32X4 top_w = MS_MOVQ_F32(1.0f - y_bottom_weight); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 bottom = MS_LDQ_F32(bottom_line + w * in_c + c); + MS_FLOAT32X4 top = MS_LDQ_F32(top_line + w * in_c + c); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(MS_MULQ_F32(bottom, bottom_w), MS_MULQ_F32(top, top_w)); + MS_STQ_F32(output + w * in_c + c, interp_value); + } +#endif + for (; c < in_c; c++) { + float bottom = bottom_line[w * in_c + c]; + float top = top_line[w * in_c + c]; + output[w * in_c + c] = bottom * y_bottom_weight + top * (1.0f - y_bottom_weight); + } + } + return 0; +} + +void Bilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottom, const int32_t *y_top, const int32_t *x_left, const int32_t *x_right, + const float *y_bottom_weight, const float *x_left_weight, float *line0, float *line1, const int h_begin, + const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + bool cache_line_used[2] = {false, false}; + int cache_line_num[2] = {-1, -1}; + float *const cache_line_ptr[2] = {line0, line1}; + float *current_line_ptr[2] = {line0, line1}; + int current_line_num[2] = {-1, -1}; + + for (int h = h_begin; h < h_end; h++) { + current_line_num[0] = y_bottom[h]; + current_line_num[1] = y_top[h]; + + for (int i = 0; i < 2; i++) { + cache_line_used[i] = false; + } + // search if we cached + for (int j = 0; j < 2; j++) { + bool find = false; + for (int k = 0; k < 2; k++) { + if (current_line_num[j] == cache_line_num[k]) { + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + find = true; + break; + } + } + + if (!find) { + const float *line = input_data + current_line_num[j] * in_w * in_c; + for (int k = 0; k < 2; k++) { + if (!cache_line_used[k]) { + cache_line_num[k] = current_line_num[j]; + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + InterpRow(line, current_line_ptr[j], new_width, x_left_weight, x_left, x_right, in_c); + break; + } + } + } + } + // do col interp + InterpCol(current_line_ptr[0], current_line_ptr[1], output_data + h * h_stride, new_width, y_bottom_weight[h], + in_c); + } +} + +int ResizeBilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, const int32_t *x_rights, + const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, + const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL || + y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_b = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int b = 0; b < in_b; b++) { + 
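// advance to batch b of the NHWC input and output +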
const float *input = input_data + b * in_h * in_w * in_c; + float *output = output_data + b * new_height * new_width * in_c; + Bilinear(input, output, input_shape, output_shape, y_bottoms, y_tops, x_lefts, x_rights, y_bottom_weights, + x_left_weights, line0, line1, h_begin, h_end); + } + return NNACL_OK; +} + +void BicubicInterpRow(const float *src, float *dst, const float *weights, const int32_t *lefts, int width, + int channel) { + for (int w = 0; w < width; w++) { + const float *weight = weights + 4 * w; + float *dst_w = dst + w * channel; + const float *src0_w = src + lefts[4 * w] * channel; + const float *src1_w = src + lefts[4 * w + 1] * channel; + const float *src2_w = src + lefts[4 * w + 2] * channel; + const float *src3_w = src + lefts[4 * w + 3] * channel; + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weight[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weight[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weight[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weight[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst0 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst3, MS_ADD256_F32(dst2, MS_ADD256_F32(dst1, dst0))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weight[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weight[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weight[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weight[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst0 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst2, MS_ADDQ_F32(dst1, dst0))); + MS_STQ_F32(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weight[0] + src1_w[c] * weight[1] + src2_w[c] * weight[2] + src3_w[c] * weight[3]; + } + } +} + +void BicubicInterpCol(const float *src, float *dst, const float *weights, int width, int channel) { + const float *src0 = src; + const float *src1 = src + width * channel; + const float *src2 = src + 2 * width * channel; + const float *src3 = src + 3 * width * channel; + for (int w = 0; w < width; w++) { + float *dst_w = dst + w * channel; + const float *src0_w = src0 + w * channel; + const float *src1_w = src1 + w * channel; + const float *src2_w = src2 + w * channel; + const float *src3_w = src3 + w * channel; + int c = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weights[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weights[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weights[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weights[3]); + for (; c <= channel - C8NUM; c += C8NUM) 
{ + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst4 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst4, MS_ADD256_F32(dst3, MS_ADD256_F32(dst1, dst2))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weights[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weights[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weights[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weights[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst4 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst4, MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst1, dst2))); + MS_STQ_F32(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weights[0] + src1_w[c] * weights[1] + src2_w[c] * weights[2] + src3_w[c] * weights[3]; + } + } +} + +void Bicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + for (int h = h_begin; h < h_end; h++) { + for (int i = 0; i < 4; ++i) { + BicubicInterpRow(input_data + y_tops[4 * h + i] * in_w * in_c, line_buffer + i * h_stride, x_weights, x_lefts, + new_width, in_c); + } + BicubicInterpCol(line_buffer, output_data + h * h_stride, y_weights + 4 * h, new_width, in_c); + } +} + +int ResizeBicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_tops == NULL || + x_lefts == NULL || y_weights == NULL || x_weights == NULL) { + return NNACL_NULL_PTR; + } + int input_cube_per_batch = input_shape[1] * input_shape[2] * input_shape[3]; + int output_cube_per_batch = output_shape[1] * output_shape[2] * input_shape[3]; + for (int b = 0; b < input_shape[0]; b++) { + const float *input = input_data + b * input_cube_per_batch; + float *output = output_data + b * output_cube_per_batch; + Bicubic(input, output, input_shape, output_shape, y_tops, x_lefts, y_weights, x_weights, line_buffer, h_begin, + h_end); + } + return NNACL_OK; +} + +int RewriteExtrapolationValue(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const 
int32_t *y_tops, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || box_idx == NULL || input_shape == NULL || output_shape == NULL) { + return NNACL_NULL_PTR; + } + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + int new_channel = output_shape[3]; + int input_h = input_shape[1]; + int input_w = input_shape[2]; + + for (int b = 0; b < batch; b++) { + float *output = output_data + b * new_height * new_width * new_channel; + const float *box = boxes + 4 * b; + float start_h = box[0]; + float end_h = box[2]; + float start_w = box[1]; + float end_w = box[3]; + float actual_y, actual_x; + for (int h = h_begin; h < h_end; ++h) { + if (new_height > 1) { + actual_y = start_h * (input_h - 1) + h * (end_h - start_h) * (input_h - 1) / (new_height - 1); + } else { + actual_y = 0.5 * (end_h + start_h) * (input_h - 1); + } + if (actual_y < 0 || actual_y > input_h - 1) { + float *output_data_base = output + h * new_width * new_channel; + for (int x = 0; x < new_width; ++x) { + for (int d = 0; d < new_channel; ++d) { + *output_data_base = extrapolation_value; + output_data_base++; + } + } + } + for (int w = 0; w < new_width; ++w) { + if (new_width > 1) { + actual_x = start_w * (input_w - 1) + w * (end_w - start_w) * (input_w - 1) / (new_width - 1); + } else { + actual_x = 0.5 * (end_w + start_w) * (input_w - 1); + } + if (actual_x < 0 || actual_x > input_w - 1) { + float *output_data_base = output + h * new_width * new_channel + w * new_channel; + for (int d = 0; d < new_channel; ++d) { + output_data_base[d] = extrapolation_value; + } + } + } + } + } + return NNACL_OK; +} + +int CropAndResizeBilinear(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, + const int32_t *x_rights, const float *y_bottom_weights, const float *x_left_weights, + float *line0, float *line1, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || box_idx == NULL || input_shape == NULL || output_shape == NULL || + y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || + x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + int new_channel = output_shape[3]; + int input_h = input_shape[1]; + int input_w = input_shape[2]; + + for (int b = 0; b < batch; b++) { + const float *cur_img = input_data + box_idx[b] * input_h * input_w * new_channel; + const int32_t *y_bottom = y_bottoms + b * new_height; + const int32_t *y_top = y_tops + b * new_height; + const float *y_bottom_weight = y_bottom_weights + b * new_height; + const int32_t *x_left = x_lefts + b * new_width; + const int32_t *x_right = x_rights + b * new_width; + const float *x_left_weight = x_left_weights + b * new_width; + float *output = output_data + b * new_height * new_width * new_channel; + + Bilinear(cur_img, output, input_shape, output_shape, y_bottom, y_top, x_left, x_right, y_bottom_weight, + x_left_weight, line0, line1, h_begin, h_end); + } + RewriteExtrapolationValue(input_data, output_data, box_idx, boxes, extrapolation_value, input_shape, output_shape, + y_tops, h_begin, h_end); + return NNACL_OK; +} + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int32_t 
*input_shape, + const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num) { + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int c = input_shape[3]; + bool align_corners = coordinate_transform_mode == 1; + for (int batch = 0; batch < output_shape[0]; batch++) { + for (int y = tid; y < output_shape[1]; y += thread_num) { + float actual_y = calculate(y, input_shape[1], output_shape[1]); + int input_y; + if (align_corners) { + input_y = (int)(roundf(actual_y)); + } else { + input_y = (int)(floorf(actual_y)); + } + for (int x = 0; x < output_shape[2]; x++) { + float actual_x = calculate(x, input_shape[2], output_shape[2]); + int input_x; + if (align_corners) { + input_x = (int)(roundf(actual_x)); + } else { + input_x = (int)(floorf(actual_x)); + } + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float)); + } + } + } + return NNACL_OK; +} + +float CalculateAsymmetric(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized) / (float)(length_original); + return (float)(x_resized) / scale; +} + +float CalculateAlignCorners(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized - 1) / (float)(length_original - 1); + return (float)(x_resized) / scale; +} + +float CalculateHalfPixel(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized) / (float)(length_original); + float actual = (float)(x_resized + 0.5) / scale - 0.5; + return actual > 0 ? actual : 0; +} + +int CheckCropAndResizeBoxIdx(int32_t *box_idx, int32_t num_boxes, int32_t batch) { + for (int i = 0; i < num_boxes; i++) { + if (box_idx[i] < 0 || box_idx[i] >= batch) { + return NNACL_ERR; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.h new file mode 100644 index 00000000..951b0fee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/resize_fp32.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FP32_RESIZE_H_ +#define MINDSPORE_NNACL_FP32_RESIZE_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include +#include "nnacl_c/resize_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef float (*CalculateOriginalCoordinate)(int x_resized, int length_original, int length_resized); + +int PrepareResizeBilinear(const int32_t *input_shape, const int32_t *output_shape, + CalculateOriginalCoordinate calculate, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights); + +int PrepareResizeBicubic(const int32_t *input_shape, const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int32_t *y_tops, int32_t *x_lefts, float *y_weights, float *x_weights, float cubic_coeff); + +int ResizeBilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, const int32_t *x_rights, + const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, + const int h_begin, const int h_end); + +int ResizeBicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end); + +int PrepareCropAndResizeBilinear(const int32_t *input_shape, const float *boxes, const int32_t *box_idx, + const int32_t *output_shape, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights); + +int CropAndResizeBilinear(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, + const int32_t *x_rights, const float *y_bottom_weights, const float *x_left_weights, + float *line0, float *line1, const int h_begin, const int h_end); + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int32_t *input_shape, + const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num); + +float CalculateAsymmetric(int x_resized, int length_original, int length_resized); + +float CalculateAlignCorners(int x_resized, int length_original, int length_resized); + +float CalculateHalfPixel(int x_resized, int length_original, int length_resized); + +int CheckCropAndResizeBoxIdx(int32_t *box_idx, int32_t num_boxes, int32_t batch); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_RESIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.c new file mode 100644 index 00000000..0886c07e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.c @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/reverse_fp32.h" +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/nnacl_utils.h" + +int Reverse(const float *input, float *output, size_t elem_size, int32_t *index) { + for (size_t i = 0; i < elem_size; i++) { + NNACL_ASSERT(index[i] >= 0); + output[index[i]] = input[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.h new file mode 100644 index 00000000..d95ac633 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_REVERSE_FP32_H_ +#define NNACL_FP32_REVERSE_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/reverse_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Reverse(const float *input, float *output, size_t elem_size, int32_t *index); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_REVERSE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.c new file mode 100644 index 00000000..cb93db87 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/reverse_sequence_fp32.h" + +void ReverseSequence(const float *input0, const void *input1, float *output, ReverseSequenceParameter *para) { + (void)memcpy(output, input0, para->total_data_size_); + ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_); + ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_); + for (int i = 0; i < para->outer_count_; ++i) { + const float *in = input0 + i * para->outer_stride_; + float *out = output + i * para->outer_stride_; + for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) { + const float *in_batch = in + batch * para->input_stride_[para->batch_axis_]; + float *out_batch = out + batch * para->output_stride_[para->batch_axis_]; + int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch); + NNACL_CHECK_TRUE_RET_VOID(seq_length <= para->input_shape0_[para->seq_axis_]); + for (int n = 0; n < seq_length; ++n) { + const float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_]; + float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_]; + for (int j = 0; j < para->inner_count_; ++j) { + (void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_); + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.h new file mode 100644 index 00000000..2538f22a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/reverse_sequence_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ +#define MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ + +#include <string.h> +#include "nnacl_c/common_func.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/reverse_sequence_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ReverseSequence(const float *input0, const void *input1, float *output, ReverseSequenceParameter *para); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c new file mode 100644 index 00000000..fbdf194f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.c @@ -0,0 +1,147 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/rmsprop_fp32.h" +#ifdef ENABLE_SSE +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif +#endif + +#ifdef ENABLE_AVX +#include <immintrin.h> +#endif + +#include <math.h> +#include "nnacl_c/errorcode.h" + +int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum, + float learning_rate, float decay, float epsilon, size_t start, size_t end) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + float *variable_ptr = variable + start; + float *mean_square_ptr = mean_square + start; + float *gradients_ptr = gradients + start; + float *moment_ptr = moment + start; + + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); + __m256 lr_r = _mm256_set1_ps(learning_rate); + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { + gradient_r = _mm256_loadu_ps(gradients_ptr); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); + _mm256_storeu_ps(mean_square_ptr, mean_square_r); + + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r); + avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1); + + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2); + _mm256_storeu_ps(moment_ptr, avx_r1); + + variable_r = _mm256_loadu_ps(variable_ptr); + variable_r = _mm256_sub_ps(variable_r, avx_r1); + _mm256_storeu_ps(variable_ptr, variable_r); + + gradients_ptr += C8NUM; + mean_square_ptr += C8NUM; + moment_ptr += C8NUM; + variable_ptr += C8NUM; + } +#endif + + for (; c1 < end; c1++) { + mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay); + moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(mean_square[c1] + epsilon); + variable[c1] -= moment[c1]; + } + return NNACL_OK; +} + +int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients, + float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + float *variable_ptr = variable + start; + float *mean_gradients_ptr = mean_gradients + start; + float *mean_square_ptr = mean_square + start; + float *moment_ptr = moment + start; + float *gradients_ptr = gradients + start; + + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); + __m256 lr_r = _mm256_set1_ps(learning_rate); + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r; + __m256 avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { + grad_r = _mm256_loadu_ps(gradients_ptr); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), 
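/* g * g - ms */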
mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); + _mm256_storeu_ps(mean_square_ptr, mean_square_r); + + mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr); + avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r); + mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1); + _mm256_storeu_ps(mean_gradients_ptr, mean_grad_r); + + avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r)); + __m256 denom_r = _mm256_add_ps(avx_r1, epsi_r); + __m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS); + __m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r); + + avx_r1 = _mm256_mul_ps(grad_r, lr_r); + avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r)); + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_mul_ps(moment_r, momentum_r); + avx_r1 = _mm256_add_ps(avx_r1, avx_r2); + moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r); + _mm256_storeu_ps(moment_ptr, moment_r); + + variable_r = _mm256_loadu_ps(variable_ptr); + avx_r1 = _mm256_sub_ps(variable_r, moment_r); + variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r); + _mm256_storeu_ps(variable_ptr, variable_r); + + variable_ptr += C8NUM; + mean_gradients_ptr += C8NUM; + mean_square_ptr += C8NUM; + gradients_ptr += C8NUM; + moment_ptr += C8NUM; + } +#endif + + for (; c1 < end; c1++) { + mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay); + mean_gradients[c1] += (gradients[c1] - mean_gradients[c1]) * (1.0 - decay); + float denom = (mean_square[c1] - mean_gradients[c1] * mean_gradients[c1]) + epsilon; + if (denom > 0) { + moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(denom); + variable[c1] -= moment[c1]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h new file mode 100644 index 00000000..d0b5f1f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/rmsprop_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_RMSPROP_FP32_H +#define MINDSPORE_NNACL_RMSPROP_FP32_H + +#include <stddef.h> +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum, + float learning_rate, float decay, float epsilon, size_t start, size_t end); + +int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients, + float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RMSPROP_FP32_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.c new file mode 100644 index 00000000..6b6940f4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.c @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/roi_pooling_fp32.h" +#include <math.h> +#include <float.h> +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/op_base.h" + +int ROIPooling(const float *in_ptr, float *out_ptr, const float *roi, float *max_c, int tid, + const ROIPoolingParameter *param) { + if (param->thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + int num_rois = param->output_n_; + int units = UP_DIV(num_rois, param->thread_num_); + int roi_st = tid * units; + int roi_end = MSMIN(num_rois, roi_st + units); + if (roi_st >= num_rois) { + return NNACL_OK; + } + int batch_size = param->input_n_; + int height_ = param->input_h_; + int width_ = param->input_w_; + int channels_ = param->input_c_; + float scale = param->scale_; + int pooled_height = param->pooledH_; + int pooled_width = param->pooledW_; + const int roi_stride = 5; + int roi_ind_st = roi_st * roi_stride; + for (int i = roi_st; i < roi_end; ++i) { + int roi_batch_ind = (int)roi[roi_ind_st]; // batch_index + if (roi_batch_ind >= batch_size) { + return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; + } + int roi_start_h = (int)roundf(roi[roi_ind_st + 1] * scale); // top-left x1 + int roi_start_w = (int)roundf(roi[roi_ind_st + 2] * scale); // top-left y1 + int roi_end_h = (int)roundf(roi[roi_ind_st + 3] * scale); // bottom-right x2 + int roi_end_w = (int)roundf(roi[roi_ind_st + 4] * scale); // bottom-right y2 + + int roi_height = MSMAX(roi_end_h - roi_start_h + 1, 1); + int roi_width = MSMAX(roi_end_w - roi_start_w + 1, 1); + + float bin_size_h = (float)roi_height / (float)pooled_height; + float bin_size_w = (float)roi_width / (float)pooled_width; + const float *batch_data = in_ptr + param->in_strides_[kNHWC_N] * roi_batch_ind; + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = (int)floorf(ph * bin_size_h); // block xi_1 + int wstart = (int)floorf(pw * bin_size_w); // block yi_1 + int hend = (int)ceilf((ph + 1) * bin_size_h); // block xi_2 + int 
wend = (int)ceilf((pw + 1) * bin_size_w); // block yi_2 + + hstart = MSMIN(MSMAX(hstart + roi_start_h, 0), height_); + hend = MSMIN(MSMAX(hend + roi_start_h, 0), height_); + wstart = MSMIN(MSMAX(wstart + roi_start_w, 0), width_); + wend = MSMIN(MSMAX(wend + roi_start_w, 0), width_); + bool is_empty = (hend <= hstart) || (wend <= wstart); + for (int j = 0; j < channels_; ++j) { + max_c[j] = is_empty ? 0 : -FLT_MAX; + } + int pooled_index = i * param->out_strides_[0] + ph * param->out_strides_[1] + pw * param->out_strides_[2]; + int bd_index = hstart * param->in_strides_[1]; + for (int h = hstart; h < hend; ++h) { + int wi = bd_index + wstart * param->in_strides_[2]; + for (int w = wstart; w < wend; ++w) { + for (int c = 0; c < channels_; ++c) { + max_c[c] = MSMAX(batch_data[wi + c], max_c[c]); + } + wi += param->in_strides_[2]; + } // in_w end; + bd_index += param->in_strides_[1]; + } // in_h end + for (int j = 0; j < channels_; ++j) { + out_ptr[pooled_index + j] = max_c[j]; + } + } + } + roi_ind_st += roi_stride; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.h new file mode 100644 index 00000000..1566f468 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/roi_pooling_fp32.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ROI_POOLING_H_ +#define MINDSPORE_NNACL_FP32_ROI_POOLING_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ROIPoolingParameter { + // primitive parameter + OpParameter op_parameter_; + int pooledW_; + int pooledH_; + float scale_; + + // shape correlative + int in_strides_[DIMENSION_4D]; + int out_strides_[DIMENSION_4D]; + int ndim_; + int input_w_; + int input_h_; + int input_n_; + int input_c_; + int output_w_; + int output_h_; + int output_n_; + int output_c_; + + // other parameter + int thread_num_; +} ROIPoolingParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int ROIPooling(const float *in_ptr, float *out_ptr, const float *roi, float *max_c, int tid, + const ROIPoolingParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ROI_POOLING_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.c new file mode 100644 index 00000000..32fd2319 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.c @@ -0,0 +1,304 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/scale_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void ScaleInner(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 result = MS_MLA256_F32(offset_8, data, scale_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + for (; in_index <= inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 result = MS_MLAQ_F32(offset_4, data, scale_4); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; + } + } + } +} + +void ScaleAxis(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#if defined(ENABLE_AVX) + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 result = MS_MLA256_F32(offset_8, data, scale_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 result = MS_MLAQ_F32(offset_4, data, scale_4); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index]; + } + } +} + +void DoScale(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, 
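/* clamp the final thread's chunk to outer_size_ */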
scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void ScaleInnerRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MAX256_F32(tmp, zeros_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + for (; in_index <= inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MAXQ_F32(tmp, zeros); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } + } +} + +void ScaleAxisRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_AVX + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MAX256_F32(tmp, zeros_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MAXQ_F32(tmp, zeros); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = tmp > 0.0f ? 
tmp : 0.0f; + } + } +} + +void DoScaleRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void ScaleInnerRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; + MS_FLOAT32X8 bounds_8 = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; + MS_FLOAT32X4 bounds = {6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MIN256_F32(MS_MAX256_F32(tmp, zeros_8), bounds_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; in_index < inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MINQ_F32(MS_MAXQ_F32(tmp, zeros), bounds); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } + } +} + +void ScaleAxisRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; + MS_FLOAT32X8 bounds_8 = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; + MS_FLOAT32X4 bounds = {6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_AVX + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MIN256_F32(MS_MAX256_F32(tmp, zeros_8), bounds_8); + MS_ST256_F32(out_data + in_offset, result); + } 
+#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MINQ_F32(MS_MAXQ_F32(tmp, zeros), bounds); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } +} + +void DoScaleRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.h new file mode 100644 index 00000000..40ef1d36 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/scale_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_SCALE_FP32_H_ +#define NNACL_FP32_SCALE_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/scale.h" +#ifdef __cplusplus +extern "C" { +#endif +void DoScale(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param); +void DoScaleRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param); +void DoScaleRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_SCALE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.c new file mode 100644 index 00000000..9e3d1d98 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.c @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
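[Editor's note] A minimal driver sketch for the scale kernels above (hypothetical code, not part of this patch; it assumes only the ScaleStruct fields that DoScale itself reads: base_.thread_nr_, outer_size_, axis_size_, inner_size_):

#include "nnacl_c/fp32/scale_fp32.h"

/* Scale a [2, 3, 4] tensor along axis 1: out = in * scale + offset, with one
 * scale/offset value per axis element, broadcast over the inner dimension. */
static void ExampleDoScale(void) {
  ScaleStruct param;
  param.base_.thread_nr_ = 1; /* single thread, so task_id 0 covers everything */
  param.outer_size_ = 2;      /* product of dims before the scaled axis */
  param.axis_size_ = 3;       /* length of the scaled axis */
  param.inner_size_ = 4;      /* product of dims after the scaled axis */
  float in[24] = {0.0f};
  float out[24];
  float scale[3] = {1.0f, 2.0f, 3.0f};
  float offset[3] = {0.5f, 0.5f, 0.5f};
  DoScale(in, out, scale, offset, /*task_id=*/0, &param);
  /* DoScaleRelu and DoScaleRelu6 take the same arguments and clamp the
   * result at 0 (respectively to [0, 6]) on top of the affine transform. */
}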
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/softmax_fp32.h" +#include <math.h> +#include <float.h> +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/softmax_fp32_simd.h" + +void SoftmaxNorm(const float *src, float *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + int index = 0; + float max = -FLT_MAX; + + SIMD_RUN_NO_SCALAR(SoftmaxNormGetMax, index, src, cur_batch_offset, &max, channel); + for (; index < channel; index++) { + float input = src[cur_batch_offset + index]; + if (input > max) { + max = input; + } + } + + index = 0; + SIMD_RUN_NO_SCALAR(SoftmaxNormCalcNorm, index, src, dst, cur_batch_offset, max, channel); + for (; index < channel; index++) { + int offset = cur_batch_offset + index; + dst[offset] = src[offset] - max; + } + } +} + +int SoftmaxLastAxis(const float *src, float *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + int index = 0; + + // get channel's max value + float max = -FLT_MAX; + SIMD_RUN_NO_SCALAR(SoftmaxNormGetMax, index, src, cur_batch_offset, &max, channel); + for (; index < channel; index++) { + float input = src[cur_batch_offset + index]; + if (input > max) { + max = input; + } + } + + // get channel's exp sum value + float exp_sum = 0.0f; + index = 0; + SIMD_RUN_NO_SCALAR(SoftmaxLastAxisGetExpSum, index, src, dst, cur_batch_offset, max, &exp_sum, channel); + for (; index < channel; index++) { + int offset = cur_batch_offset + index; + float exp_out = simd_exp32_f32(src[offset] - max); + exp_sum += exp_out; + dst[offset] = exp_out; + } + + // get result + NNACL_CHECK_TRUE_RET(exp_sum != 0, NNACL_ERR); + exp_sum = 1.0f / exp_sum; + index = 0; + SIMD_RUN_NO_SCALAR(SoftmaxLastAxisGetResult, index, dst, dst, cur_batch_offset, exp_sum, channel); + for (; index < channel; index++) { + dst[cur_batch_offset + index] = dst[cur_batch_offset + index] * exp_sum; + } + } + return NNACL_OK; +} + +// output = exp(input) / reduce_sum(exp(input), axis) +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, int axis, int n_dim, + const int32_t *input_shape) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ?
max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); + sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.h new file mode 100644 index 00000000..36d5acc6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/softmax_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, int axis, int n_dim, + const int32_t *input_shape); +int SoftmaxLastAxis(const float *src, float *dst, int batch, int channel); +void SoftmaxNorm(const float *src, float *dst, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SOFTMAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32_simd.h.in new file mode 100644 index 00000000..762d2c75 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_fp32_simd.h.in @@ -0,0 +1,80 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
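[Editor's note] Usage sketch for the softmax kernels above (hypothetical, not part of this patch). SoftmaxLastAxis fuses max-subtraction, exponentiation and normalization for the common last-axis case; the generic Softmax needs a caller-provided sum_data scratch buffer with one float per (outer, inner) pair:

#include "nnacl_c/fp32/softmax_fp32.h"

static void ExampleSoftmax(void) {
  const float src[8] = {1.0f, 2.0f, 3.0f, 4.0f, 0.0f, 0.0f, 0.0f, 0.0f};
  float dst[8];
  /* Two rows of four logits; each output row sums to 1. */
  (void)SoftmaxLastAxis(src, dst, /*batch=*/2, /*channel=*/4);

  /* Generic form on shape [2, 4], softmax over axis 1: outer = 2, inner = 1,
   * so sum_data needs 2 floats of scratch. */
  const int32_t shape[2] = {2, 4};
  float sum_data[2];
  Softmax(src, dst, sum_data, /*axis=*/1, /*n_dim=*/2, shape);
}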
+ */ + +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t SoftmaxNormGetMax@SIMD_INSTRUCTION@(int64_t index, const float *src, int cur_batch_offset, + float *max, int channel) { + if (channel >= BLOCK_NUM * BLOCK_NUM) { + SIMD_F32 max_val = SIMD_MOV_F32(*max); + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + max_val = SIMD_MAX_F32(max_val, SIMD_LD_F32(src + cur_batch_offset + index)); + } + *max = SIMD_GET_MAX_F32(max_val); + } + return index; +} + +static inline int64_t SoftmaxNormCalcNorm@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float max, int channel) { + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 output = SIMD_SUB_F32(SIMD_LD_F32(src + cur_batch_offset + index), SIMD_MOV_F32(max)); + SIMD_ST_F32(dst + cur_batch_offset + index, output); + } + return index; +} + +static inline int64_t SoftmaxLastAxisGetExpSum@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float max, float *exp_sum, int channel) { +#ifndef _WIN32 + SIMD_F32 sum_val = SIMD_SET0_F32; + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index); + SIMD_F32 output = SIMD_SUB_F32(input, SIMD_MOV_F32(max)); + SIMD_F32 exp_out = SIMD_EXP_F32(output); + sum_val = SIMD_ADD_F32(sum_val, exp_out); + SIMD_ST_F32(dst + cur_batch_offset + index, exp_out); + } + *exp_sum += SIMD_GET_SUM_F32(sum_val); +#endif + return index; +} + +static inline int64_t SoftmaxLastAxisGetResult@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float exp_sum, int channel) { + SIMD_F32 exp_sum_val = SIMD_MOV_F32(exp_sum); + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index); + SIMD_F32 output = SIMD_MUL_F32(input, exp_sum_val); + SIMD_ST_F32(dst + cur_batch_offset + index, output); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.c new file mode 100644 index 00000000..a9cdf4a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.c @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
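[Editor's note] The *_simd.h.in templates above are expanded per instruction set at build time (the @SIMD_INSTRUCTION@ placeholders become the SSE/AVX/NEON-style variants; the exact set list comes from the build system) and share one contract: starting at index, consume as many full BLOCK_NUM-wide blocks as fit, then return the first unprocessed element so the caller's scalar loop finishes the tail. A scalar stand-in for SoftmaxNormGetMax illustrates that contract (sketch only, not generated code; the block width 4 is a placeholder for BLOCK_NUM):

static int64_t SoftmaxNormGetMaxScalarSketch(int64_t index, const float *src, int cur_batch_offset, float *max,
                                             int channel) {
  const int kBlock = 4; /* stand-in for the instruction set's BLOCK_NUM */
  for (int block_max_size = channel - kBlock + 1; index < block_max_size; index += kBlock) {
    for (int i = 0; i < kBlock; ++i) { /* the real variant does this in one SIMD max */
      float v = src[cur_batch_offset + index + i];
      *max = v > *max ? v : *max;
    }
  }
  return index; /* caller handles elements [index, channel) one by one */
}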
+ */ + +#include "nnacl_c/fp32/softmax_grad_fusion_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/softmax_grad_fusion_fp32_simd.h" +#include "nnacl_c/sub_fp32_simd.h" + +void SoftmaxGradFusionOpt(const float *a, const float *b, float *dst, int64_t m) { + float result = 0; + + int64_t i = 0; + SIMD_RUN_NO_SCALAR(SoftmaxGradFusionOpt, i, a, b, &result, m); + for (; i < m; i++) { + result += a[i] * b[i]; + } + + i = 0; + SIMD_RUN_NO_SCALAR(ElementOptSubMul, i, a, b, result, dst, m); + for (; i < m; i++) { + dst[i] = a[i] * (b[i] - result); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.h new file mode 100644 index 00000000..5c69f9fd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SoftmaxGradFusionOpt(const float *a, const float *b, float *dst, int64_t m); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32_simd.h.in new file mode 100644 index 00000000..01a1de64 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/softmax_grad_fusion_fp32_simd.h.in @@ -0,0 +1,55 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
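[Editor's note] SoftmaxGradFusionOpt above computes dst = a * (b - sum(a * b)) in two passes: a fused dot product, then a fused subtract-multiply. That matches the softmax backward dx = y * (dy - sum(y * dy)) when a is the softmax output and b the upstream gradient; this reading is the editor's, not stated by the patch. Hypothetical check:

#include "nnacl_c/fp32/softmax_grad_fusion_fp32.h"

static void ExampleSoftmaxGradFusion(void) {
  const float y[4] = {0.1f, 0.2f, 0.3f, 0.4f};  /* softmax output, sums to 1 */
  const float dy[4] = {1.0f, 0.0f, 0.0f, 0.0f}; /* upstream gradient */
  float dx[4];
  SoftmaxGradFusionOpt(y, dy, dx, 4);
  /* dot = 0.1, so dx[0] = 0.1 * (1.0 - 0.1) = 0.09 and
   *              dx[1] = 0.2 * (0.0 - 0.1) = -0.02. */
}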
+ */ +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t SoftmaxGradFusionOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + result_vec = SIMD_FMADD_F32(a_vec, b_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +static inline int64_t ElementOptSubMul@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float sum, + float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(sum); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_ST_F32(out + index, SIMD_MUL_F32(vin0, SIMD_SUB_F32(vin1, vin1_opt_))); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.c new file mode 100644 index 00000000..b2b026b1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.c @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/space_to_batch_fp32.h" +#include "nnacl_c/errorcode.h" + +int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id) { + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + const int input_batch = param->input_shape_[0]; + const int input_height = param->input_shape_[1]; + const int input_width = param->input_shape_[2]; + + const int output_batch = param->output_shape_[0]; + const int output_height = param->output_shape_[1]; + const int output_width = param->output_shape_[2]; + + const int block_shape_height = param->block_sizes_[0]; + const int block_shape_width = param->block_sizes_[1]; + const int padding_top = param->paddings_[0]; + const int padding_left = param->paddings_[2]; + + NNACL_CHECK_ZERO_RETURN_ERR(input_batch); + NNACL_CHECK_ZERO_RETURN_ERR(block_shape_width); + int copy_size = param->input_shape_[3] * param->data_type_len; + for (int64_t out_b = task_id; out_b < output_batch; out_b += param->op_parameter_.thread_num_) { + int in_b = out_b % input_batch; + int shift_w = (out_b / input_batch) % block_shape_width; + int shift_h = (out_b / input_batch) / block_shape_width; + for (int out_h = 0; out_h < output_height; out_h++) { + for (int out_w = 0; out_w < output_width; out_w++) { + int64_t output_offset = + out_b * param->out_stride_[0] + out_h * param->out_stride_[1] + out_w * param->out_stride_[2]; + if (out_h * block_shape_height + shift_h < padding_top || + out_h * block_shape_height + shift_h >= padding_top + input_height || + out_w * block_shape_width + shift_w < padding_left || + out_w * block_shape_width + shift_w >= padding_left + input_width) { + memset((int8_t *)output + output_offset * param->data_type_len, 0, copy_size); + } else { + int in_h = (out_h * block_shape_height + shift_h) - padding_top; + int in_w = (out_w * block_shape_width + shift_w) - padding_left; + int input_offset = in_b * param->in_stride_[0] + in_h * param->in_stride_[1] + in_w * param->in_stride_[2]; + memcpy((int8_t *)output + output_offset * param->data_type_len, + (const int8_t *)input + input_offset * param->data_type_len, copy_size); + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.h new file mode 100644 index 00000000..89a1abc5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/space_to_batch_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
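[Editor's note] Usage sketch for DoSpaceToBatch above (hypothetical; field names follow the SpaceToBatchParameter struct declared in the header just below). An NHWC 1x4x4x1 input with 2x2 blocks and no padding becomes a 4x2x2x1 output; the strides are element counts per batch/row/column:

#include <string.h>
#include "nnacl_c/fp32/space_to_batch_fp32.h"

static void ExampleSpaceToBatch(const float in[16], float out[16]) {
  SpaceToBatchParameter p;
  memset(&p, 0, sizeof(p)); /* paddings_ all zero; [0] = top, [2] = left as read above */
  p.op_parameter_.thread_num_ = 1;
  int in_shape[4] = {1, 4, 4, 1};
  int out_shape[4] = {4, 2, 2, 1};
  memcpy(p.input_shape_, in_shape, sizeof(in_shape));
  memcpy(p.output_shape_, out_shape, sizeof(out_shape));
  p.block_sizes_[0] = 2; /* block height */
  p.block_sizes_[1] = 2; /* block width */
  p.in_stride_[0] = 16; p.in_stride_[1] = 4; p.in_stride_[2] = 1;
  p.out_stride_[0] = 4; p.out_stride_[1] = 2; p.out_stride_[2] = 1;
  p.data_type_len = (int)sizeof(float);
  (void)DoSpaceToBatch(in, out, &p, /*task_id=*/0);
}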
+ */ +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ + +#include <string.h> +#include "nnacl_c/op_base.h" + +typedef struct SpaceToBatchParameter { + // primitive parameter + OpParameter op_parameter_; + int block_sizes_[4]; + int paddings_[4]; + + // shape correlative + int input_shape_[4]; + int output_shape_[4]; + int in_stride_[4]; + int out_stride_[4]; + int padded_in_shape_[4]; + + // other parameter + bool need_paddings_; + int m_; + int data_type_len; +} SpaceToBatchParameter; +#ifdef __cplusplus +extern "C" { +#endif + +int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.c new file mode 100644 index 00000000..c1ef67cc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.c @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/errorcode.h" + +int SparseToDenseSetDefault(float *output, float default_value, const SparseToDenseParameter *param, int task_id) { + if (output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->output_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->output_num); + for (int i = begin; i < end; i++) { + output[i] = default_value; + } + return NNACL_OK; +} + +int SparseToDense(int32_t *indices_vec, const float *sparse_values, float default_value, float *output, + SparseToDenseParameter *param, int task_id) { + if (indices_vec == NULL || sparse_values == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->index_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->index_num); + + int stride0 = param->output_stride[0]; + int stride1 = param->output_stride[1]; + int stride2 = param->output_stride[2]; + + if (param->validate_indices_ == true) { + int index_before = -1; + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + if (index <= index_before) { + return NNACL_ERR; + } + index_before = index; + } + } + + if (param->is_scalar == true) { + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 *
indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[0]; + } + } else { + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[i]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.h new file mode 100644 index 00000000..874a2cf9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sparse_to_dense_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ +#define MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ + +#include "nnacl_c/sparse_to_dense_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SparseToDenseSetDefault(float *output, float default_value, const SparseToDenseParameter *param, int task_id); +int SparseToDense(int32_t *indices_vec, const float *sparse_values, float default_value, float *output, + SparseToDenseParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.c new file mode 100644 index 00000000..0336d6da --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.c @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
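[Editor's note] Usage sketch for the sparse-to-dense pair above (hypothetical). Index tuples are always DIMENSION_4D ints wide, output_stride holds the element strides of the first three output dims, and the default fill runs before the scatter:

#include "nnacl_c/fp32/sparse_to_dense_fp32.h"

static void ExampleSparseToDense(void) {
  SparseToDenseParameter p;
  p.op_parameter_.thread_num_ = 1;
  p.output_num = 16;          /* dense 4x4 output, viewed as 4x4x1x1 */
  p.index_num = 2;            /* two sparse entries */
  p.output_stride[0] = 4;
  p.output_stride[1] = 1;
  p.output_stride[2] = 1;
  p.validate_indices_ = true; /* flat indices must be strictly ascending */
  p.is_scalar = false;        /* one value per index tuple */
  int32_t indices[2 * 4] = {0, 1, 0, 0,   /* flat index 0*4 + 1 = 1  */
                            2, 3, 0, 0};  /* flat index 2*4 + 3 = 11 */
  const float values[2] = {5.0f, 7.0f};
  float out[16];
  (void)SparseToDenseSetDefault(out, 0.0f, &p, /*task_id=*/0);
  (void)SparseToDense(indices, values, 0.0f, out, &p, /*task_id=*/0);
}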
+ */ + +#include "nnacl_c/fp32/splice_fp32.h" +void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, + float *dst_data, int dst_row, int dst_col) { + int forward_index = 0; + for (int r = 0; r < dst_row; ++r) { + float *dst_row_data = dst_data + r * dst_col; + for (int off = 0; off < splice_parameter->context_dim_; ++off) { + int r_off = splice_parameter->forward_indexes_[forward_index]; + forward_index++; + const float *tmp_src_data = src_data + r_off * src_col; + float *tmp_dst_data = dst_row_data + off * src_col; + memcpy(tmp_dst_data, tmp_src_data, (size_t)(src_col) * sizeof(float)); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.h new file mode 100644 index 00000000..83c937c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/splice_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_SPLICE_FP32_H_ +#define NNACL_FP32_SPLICE_FP32_H_ + +#include <string.h> +#include "nnacl_c/splice_parameter.h" + +void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, + float *dst_data, int dst_row, int dst_col); + +#endif // NNACL_FP32_SPLICE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.c new file mode 100644 index 00000000..5dc4d390 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.c @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
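[Editor's note] SpliceFp32 above builds each output row by concatenating context_dim_ source rows picked through forward_indexes_, so dst_col must equal context_dim_ * src_col. Hypothetical sketch with a 3-row context over a 4x2 input (assuming forward_indexes_ is a plain int pointer, as its use above suggests):

#include "nnacl_c/fp32/splice_fp32.h"

static void ExampleSplice(void) {
  SpliceParameter p;
  p.context_dim_ = 3;
  int forward_indexes[6] = {0, 1, 2,  /* output row 0 = src rows 0|1|2 */
                            1, 2, 3}; /* output row 1 = src rows 1|2|3 */
  p.forward_indexes_ = forward_indexes;
  const float src[4 * 2] = {0, 1, 2, 3, 4, 5, 6, 7}; /* 4 rows x 2 cols */
  float dst[2 * 6];                                  /* 2 rows x 6 cols */
  SpliceFp32(src, 4, 2, &p, dst, 2, 6);
  /* dst row 0 = {0, 1, 2, 3, 4, 5}; dst row 1 = {2, 3, 4, 5, 6, 7}. */
}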
+ */ + +#include "nnacl_c/fp32/squared_difference.h" +#include "nnacl_c/fp32/sub_fp32.h" +#include "nnacl_c/fp32/mul_fp32.h" + +int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size) { + ElementSub(in0, in1, out, size); + return ElementMul(out, out, out, size); +} + +int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size, bool scale) { + ElementOptSub(in0, in1, out, size, scale); + return ElementMul(out, out, out, size); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.h new file mode 100644 index 00000000..2d4db9de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/squared_difference.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ +#define MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ + +#ifdef ENABLE_NEON +#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* Element Squared Difference */ +int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size); +int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size, bool scale); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.c new file mode 100644 index 00000000..9fda1af4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.c @@ -0,0 +1,125 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
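[Editor's note] The squared-difference kernel above composes the sub and mul primitives in place: out = in0 - in1 followed by out = out * out, so the output buffer doubles as scratch. Hypothetical sketch:

#include "nnacl_c/fp32/squared_difference.h"

static void ExampleSquaredDifference(void) {
  const float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float b[4] = {4.0f, 3.0f, 2.0f, 1.0f};
  float out[4];
  (void)ElementSquaredDifference(a, b, out, 4); /* {9, 1, 1, 9} */
  /* Broadcast form: scale == false treats in1 as the scalar b[0] = 4.0f,
   * mirroring ElementOptSub's first_scalar flag -> {9, 4, 1, 0}. */
  (void)ElementOptSquaredDifference(a, b, out, 4, false);
}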
+ */ + +#include "nnacl_c/fp32/strided_slice_fp32.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/errorcode.h" + +int PadStridedSliceParameterTo8D(StridedSliceStruct *strided_slice) { + if (strided_slice->in_shape_size_ > DIMENSION_8D) { + return NNACL_STRIDED_SLICE_UNSUPPORTED_MAX_8D; + } + + int32_t begins[DIMENSION_8D]; + int32_t ends[DIMENSION_8D]; + int32_t strides[DIMENSION_8D]; + int32_t input_shape[DIMENSION_8D]; + int32_t i; + for (i = 0; i < strided_slice->in_shape_size_; ++i) { + begins[i] = strided_slice->begins_[i]; + ends[i] = MSMIN(strided_slice->ends_[i], strided_slice->in_shape_[i]); + strides[i] = strided_slice->strides_[i]; + input_shape[i] = strided_slice->in_shape_[i]; + } + + int32_t real_index = strided_slice->in_shape_size_ - 1; + for (i = DIMENSION_8D - 1; i >= 0; --i) { + if (real_index >= 0) { + strided_slice->begins_[i] = begins[real_index]; + strided_slice->ends_[i] = ends[real_index]; + strided_slice->strides_[i] = strides[real_index]; + strided_slice->in_shape_[i] = input_shape[real_index--]; + } else { + strided_slice->begins_[i] = 0; + strided_slice->ends_[i] = 1; + strided_slice->strides_[i] = 1; + strided_slice->in_shape_[i] = 1; + } + } + strided_slice->in_shape_size_ = DIMENSION_8D; + return NNACL_OK; +} + +bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; } + +int DoStridedSliceIn8D(const void *input, void *output, StridedSliceStruct *strided_slice) { + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + const uint8_t *in = (const uint8_t *)input; + uint8_t *out = (uint8_t *)output; + int data_type_size = (int)DataTypeCSize(strided_slice->data_type_); + + int32_t *begins = strided_slice->begins_; + int32_t *ends = strided_slice->ends_; + int32_t *strides = strided_slice->strides_; + int32_t *in_shape = strided_slice->in_shape_; + + int dim_offset[DIMENSION_8D - 1]; + dim_offset[6] = in_shape[7]; + dim_offset[5] = in_shape[6] * dim_offset[6]; + dim_offset[4] = in_shape[5] * dim_offset[5]; + dim_offset[3] = in_shape[4] * dim_offset[4]; + dim_offset[2] = in_shape[3] * dim_offset[3]; + dim_offset[1] = in_shape[2] * dim_offset[2]; + dim_offset[0] = in_shape[1] * dim_offset[1]; + size_t out_offset = 0; + int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7; + for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) { + for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) { + for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) { + for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) { + for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) { + for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) { + for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) { + for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) { + int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + + dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] + + dim6 * dim_offset[6] + dim7; + memcpy(out + out_offset * data_type_size, in + in_offset * data_type_size, data_type_size); + out_offset++; + } + } + } + } + } + } + } + } + return NNACL_OK; +} + +void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size, + size_t in_offset) { + if 
(stride == 1) { + size_t unit = split_len * inner_size; + for (size_t i = 0; i < outer; ++i) { + memcpy(output, input, unit); + output += unit; + input += in_offset; + } + return; + } + for (size_t i = 0; i < outer; ++i) { + const uint8_t *input_ptr = input + i * in_offset; + for (int j = 0; j < split_len; ++j) { + memcpy(output, input_ptr, inner_size); + output += inner_size; + input_ptr += inner_size * stride; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.h new file mode 100644 index 00000000..57ab5ce8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/strided_slice_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_STRIDED_SLICE_FP32_H_ +#define NNACL_FP32_STRIDED_SLICE_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/strided_slice_parameter.h" +#include "nnacl_c/kernel/strided_slice.h" +#ifdef __cplusplus +extern "C" { +#endif + +int PadStridedSliceParameterTo8D(StridedSliceStruct *strided_slice); +int DoStridedSliceIn8D(const void *input, void *output, StridedSliceStruct *strided_slice); + +void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size, + size_t in_offset); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_STRIDED_SLICE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c new file mode 100644 index 00000000..6975d41c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.c @@ -0,0 +1,150 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
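[Editor's note] PadStridedSliceParameterTo8D above right-aligns an arbitrary-rank slice into fixed 8-D form (leading dims become size 1 with begin 0, end 1, stride 1), so the eight-level loop nest of DoStridedSliceIn8D covers every rank. Hypothetical sketch slicing every other row and column of a 4x4 float tensor (field names per StridedSliceStruct as used above; kNumberTypeFloat32 is assumed to be the float32 entry of nnacl's type enum):

#include "nnacl_c/fp32/strided_slice_fp32.h"
#include "nnacl_c/errorcode.h"

static void ExampleStridedSlice(const float in[16], float out[4]) {
  StridedSliceStruct s;
  s.data_type_ = kNumberTypeFloat32;
  s.in_shape_size_ = 2; /* original rank: 4x4 */
  s.in_shape_[0] = 4; s.in_shape_[1] = 4;
  s.begins_[0] = 0; s.ends_[0] = 4; s.strides_[0] = 2; /* rows 0 and 2 */
  s.begins_[1] = 1; s.ends_[1] = 4; s.strides_[1] = 2; /* cols 1 and 3 */
  if (PadStridedSliceParameterTo8D(&s) != NNACL_OK) {
    return;
  }
  /* out = { in[0][1], in[0][3], in[2][1], in[2][3] } */
  (void)DoStridedSliceIn8D(in, out, &s);
}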
+ */ +#include "nnacl_c/fp32/sub_fp32.h" +#include "nnacl_c/sub_fp32_simd.h" +#include "nnacl_c/errorcode.h" + +int ElementOptSub(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] - in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptSubExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubExtNum0, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[0] - in1[index] * alpha; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubExtNum1, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[0] * alpha; + } + } + return NNACL_OK; +} + +int ElementSubExt(const float *in0, const float *in1, const float alpha, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubExt, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index] * alpha; + } + return NNACL_OK; +} + +int ElementOptSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] - in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubReluNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] - in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubReluNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] - in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubRelu6Num0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] - in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubRelu6Num1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] - in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementSub(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSub, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index]; + } + return NNACL_OK; +} + +int ElementSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index]; + } + return NNACL_OK; +} + +int ElementSubRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubRelu, index, 
in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] - in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementSubRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] - in1[index], 0), 6); + } + + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h new file mode 100644 index 00000000..66063fbf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SUB_FP32_H_ +#define MINDSPORE_NNACL_SUB_FP32_H_ + +#ifdef ENABLE_NEON
#include <arm_neon.h> +#endif +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ElementSub(const float *in0, const float *in1, float *out, int size); +int ElementSubExt(const float *in0, const float *in1, const float alpha, float *out, int size); +int ElementOptSubExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar); +int ElementSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementSubRelu(const float *in0, const float *in1, float *out, int size); +int ElementSubRelu6(const float *in0, const float *in1, float *out, int size); +int ElementOptSub(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_SUB_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32_simd.h.in new file mode 100644 index 00000000..36bc85e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/sub_fp32_simd.h.in @@ -0,0 +1,199 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
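[Editor's note] The Element*/ElementOpt* split above is the arithmetic kernels' broadcast convention: the plain form is elementwise over equal-length inputs, the Opt form treats one operand as a scalar, and first_scalar says which one. Hypothetical sketch:

#include "nnacl_c/fp32/sub_fp32.h"

static void ExampleSub(void) {
  const float a[4] = {10.0f, 20.0f, 30.0f, 40.0f};
  const float b[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float out[4];
  (void)ElementSub(a, b, out, 4);           /* {9, 18, 27, 36} */
  (void)ElementOptSub(a, b, out, 4, true);  /* a[0] - b[i]: {9, 8, 7, 6} */
  (void)ElementOptSub(a, b, out, 4, false); /* a[i] - b[0]: {9, 19, 29, 39} */
  /* The Ext variants apply an alpha factor: out = a - alpha * b. */
  (void)ElementSubExt(a, b, /*alpha=*/2.0f, out, 4); /* {8, 16, 24, 32} */
}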
+ */ + +#ifndef MINDSPORE_NNACL_FP32_SUB_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SUB_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptSubNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_SUB_F32(vin0_opt, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + + +static inline int ElementOptSubExtNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_SUB_F32(vin0_opt, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubExtNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { +SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); +SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1_opt_, valpha); +SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0_opt, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0_opt, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + 
return index; +} + +static inline int ElementOptSubReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0_opt, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSub@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSubExt@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSubInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementSubRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSubRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = 
SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.c new file mode 100644 index 00000000..ad268d8c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.c @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/topk_fp32.h" +#include "nnacl_c/errorcode.h" + +int DescendCmp(const void *a, const void *b) { + NNACL_CHECK_NULL_RETURN_ERR(a); + NNACL_CHECK_NULL_RETURN_ERR(b); + float sub = ((const TopkNode *)b)->element - ((const TopkNode *)a)->element; + if (sub > 0) { + return 1; + } else if (sub < 0) { + return -1; + } + if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) { + return 1; + } else { + return -1; + } +} + +int IndexSortCmp(const void *a, const void *b) { + if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) { + return 1; + } else { + return -1; + } +} + +void Topk(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + float *cur_input_data = (float *)input_data; + float *cur_output_data = (float *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = *(cur_input_data + offset); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmp); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), IndexSortCmp); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = top_map[m].element; + cur_output_index[offset] = top_map[m].index; + } + } + } +} + +void TopkInt(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + int32_t *cur_input_data = (int32_t *)input_data; + int32_t *cur_output_data = (int32_t *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int 
m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = (float)(*(cur_input_data + offset)); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmp); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), IndexSortCmp); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = (int)(top_map[m].element); + cur_output_index[offset] = top_map[m].index; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.h new file mode 100644 index 00000000..8adfd5be --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/topk_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_TOPK_H_ +#define MINDSPORE_NNACL_TOPK_H_ + +#include "nnacl_c/op_base.h" + +typedef struct TopkNode { + float element; + int32_t index; +} TopkNode; + +typedef struct TopkParameter { + // primitive parameter + OpParameter op_parameter_; + int k_; + int axis_; + bool sorted_; + + // other parameter + int dim_size_; + int outer_loop_num_; + int inner_loop_num_; + void *topk_node_list_; +} TopkParameter; + +#ifdef __cplusplus +extern "C" { +#endif +void Topk(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter); +void TopkInt(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_TOPK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c new file mode 100644 index 00000000..550df6dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c @@ -0,0 +1,248 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
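+ *
+ * Usage sketch (illustrative only; the shapes, perm and variable names below
+ * are assumptions, not part of the original sources). The TransposeDimNFp32
+ * helpers in this file read the input through permuted strides: the output
+ * element at position (p0, p1, ...) comes from input offset
+ * sum_i(p_i * strides[perm[i]]), so a caller precomputes strides from the
+ * input shape and out_strides from the output shape, innermost first:
+ *
+ *   int32_t perm[3] = {2, 0, 1};
+ *   int32_t in_shape[3] = {2, 3, 4};
+ *   int32_t out_shape[3] = {4, 2, 3};  // out_shape[i] = in_shape[perm[i]]
+ *   int32_t strides[3], out_strides[3];
+ *   strides[2] = out_strides[2] = 1;
+ *   for (int i = 1; i >= 0; --i) {
+ *     strides[i] = strides[i + 1] * in_shape[i + 1];
+ *     out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
+ *   }
+ *   DoTransposeFp32(in, out, out_shape, perm, strides, out_strides,
+ *                   24 * (int)sizeof(float), 3);  // data_size in bytes (memcpy fast path)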
+ */ + +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/op_base.h" + +void TransposeDim2Fp32(const float *in_data, float *out_data, const int32_t *strides, int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } +} + +void TransposeDim3Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void TransposeDim5Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + 
for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void TransposeDim6Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_m = n * out_stride4; + int stride4_m = n * stride4; + for (int g = 0; g < output5; ++g) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_m + g] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_m + g * stride5]; + } + } + } + } + } + } +} + +void TransposeDimsFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int num_axes, int task_id, int thread_num) { + const float *in_data = (const float *)in; + float *out_data = (float *)out; + NNACL_CHECK_NULL_RETURN_VOID(in_data); + NNACL_CHECK_NULL_RETURN_VOID(out_data); + NNACL_CHECK_NULL_RETURN_VOID(output_shape); + + int data_size = (*out_strides) * output_shape[0]; + int offset_size = UP_DIV(data_size, thread_num); + int task_offset = offset_size * task_id; + int count = data_size - task_offset; + if (count <= 0) { + return; + } + count = MSMIN(offset_size, count); + for (int idx = task_offset; idx < task_offset + count; ++idx) { + int pos = idx; + int output_idx = 0; + int input_idx = 0; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); + int position = pos / *(out_strides + i); + int out_stride = i < num_axes - 1 ? 
out_strides[i] : 1; + output_idx += (position * out_stride); + input_idx += (position * strides[perm[i]]); + pos -= position * (*(out_strides + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} + +int DoTransposeFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int data_size, int num_axes) { + const float *in_data = (const float *)in; + float *out_data = (float *)out; + + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; ++i) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; + } + for (int i = 0; i < num_axes; ++i) { + if (perm[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (num_axes == 2) { + TransposeDim2Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + TransposeDim3Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + TransposeDim4Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 5) { + TransposeDim5Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 6) { + TransposeDim6Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.h new file mode 100644 index 00000000..3cca6000 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
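+ *
+ * Threading note (worked example added for illustration, not from the
+ * original header): TransposeDimsFp32 splits the flattened output range
+ * across threads with offset_size = UP_DIV(data_size, thread_num). With
+ * data_size = 10 and thread_num = 4, offset_size = 3 and tasks 0..3 cover
+ * output indices [0, 3), [3, 6), [6, 9) and [9, 10); a task whose offset
+ * lands past the end returns without writing anything.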
+ */ + +#ifndef MINDSPORE_NNACL_FP32_TRANSPOSE_H_ +#define MINDSPORE_NNACL_FP32_TRANSPOSE_H_ + +#include <string.h> +#include <stdint.h> +#include "nnacl_c/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoTransposeFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int data_size, int num_axes); +void TransposeDimsFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int num_axes, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif  // MINDSPORE_NNACL_FP32_TRANSPOSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.c new file mode 100644 index 00000000..44bc0de2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.c @@ -0,0 +1,239 @@ +#ifdef BFC_MEMORY +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32/transpose_server_fp32.h" + +#define JUDGEPART(NUM) \ + if (dim_start##NUM == overflow_point##NUM) { \ + dim_start##NUM = 0; \ + } else { \ + ++dim_start##NUM; \ + in_offset += stride##NUM; \ + continue; \ + } + +void DoTransposeServerDim3(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride2 = strides[THIRD_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride2]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t last_dim = overflow_point2 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point2; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride2; + } + out_data[i + overflow_point2] = in_data[in_offset]; + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride2]; + } +} + +void DoTransposeServerDim4(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride3]; + } + int64_t dim_start1 = 
boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t last_dim = overflow_point3 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point3; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride3; + } + out_data[i + overflow_point3] = in_data[in_offset]; + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride3]; + } +} + +void DoTransposeServerDim5(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride4 = strides[FIFTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride4]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t dim_start3 = boundary_info->start_dim[FOURTH_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t overflow_point4 = overflow_points[FIFTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t last_dim = overflow_point4 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point4; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride4; + } + out_data[i + overflow_point4] = in_data[in_offset]; + JUDGEPART(3) + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride4]; + } +} + +void DoTransposeServerDim6(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride5 = strides[SIXTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride5]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t dim_start3 = boundary_info->start_dim[FOURTH_INPUT]; + int64_t dim_start4 = boundary_info->start_dim[FIFTH_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t overflow_point4 = overflow_points[FIFTH_INPUT]; + int64_t overflow_point5 = overflow_points[SIXTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t 
stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t stride4 = strides[FIFTH_INPUT]; + int64_t last_dim = overflow_point5 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point5; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride5; + } + out_data[i + overflow_point5] = in_data[in_offset]; + JUDGEPART(4) + JUDGEPART(3) + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride5]; + } +} + +void DoTransposeServer(const float *in_data, float *out_data, const int64_t *overflow_points, const int64_t *strides, + int axis_num, const TransposeBlockBoundaryInfo *boundary_info) { + if (axis_num == DIMENSION_3D) { + DoTransposeServerDim3(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_4D) { + DoTransposeServerDim4(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_5D) { + DoTransposeServerDim5(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_6D) { + DoTransposeServerDim6(in_data, out_data, overflow_points, strides, boundary_info); + return; + } + out_data += boundary_info->out_start_offset; + int64_t stride = strides[axis_num - 1]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride]; + } + int64_t dim_info[MAX_TRANSPOSE_DIM_SIZE] = {}; + for (int i = 0; i < axis_num; ++i) { + dim_info[i] = boundary_info->start_dim[i]; + } + int64_t last_overflow_point = overflow_points[axis_num - 1]; + int64_t last_dim = last_overflow_point + 1; + out_data += size; + size = boundary_info->sizes[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < last_overflow_point; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride; + } + out_data[i + last_overflow_point] = in_data[in_offset]; + int j = axis_num - 2; + while (dim_info[j] == overflow_points[j]) { + dim_info[j] = 0; + --j; + } + ++dim_info[j]; + in_offset += strides[j]; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride]; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.h new file mode 100644 index 00000000..1c1be31b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/transpose_server_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
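+ *
+ * Layout note (inferred from transpose_server_fp32.c; treat as an
+ * assumption): each TransposeBlockBoundaryInfo describes one thread's slice
+ * of the output. sizes[0] and sizes[THIRD_INPUT] are the unaligned head and
+ * tail runs copied element by element, sizes[1] is the middle section walked
+ * in whole innermost rows, in_offsets[] holds the input offsets where the
+ * head and the middle start, and start_dim[] is the multi-dimensional
+ * counter at the start of the middle section.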
+ */ +#ifndef MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ +#define MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ + +#ifdef BFC_MEMORY +#include "nnacl_c/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TransposeBlockBoundaryInfo { + int64_t out_start_offset; + int64_t sizes[C3NUM]; + int64_t in_offsets[C2NUM]; + int64_t start_dim[MAX_TRANSPOSE_DIM_SIZE]; +} TransposeBlockBoundaryInfo; + +void DoTransposeServer(const float *in_data, float *out_data, const int64_t *overflow_points, const int64_t *strides, + int axis_num, const TransposeBlockBoundaryInfo *boundary_info); +#ifdef __cplusplus +}; +#endif + +#endif // MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c new file mode 100644 index 00000000..15de2a44 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.c @@ -0,0 +1,179 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32/triu_tril_fp32.h" + +int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width) { + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + for (size_t i = 0; i < input_tensor->shape_size_; i++) { + if (input_tensor->shape_[i] <= 0) { + return NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR; + } + } + + size_t input_hw_dims = Num2; + NNACL_CHECK_FALSE(input_tensor->shape_size_ < DIMENSION_2D, NNACL_TRIU_INPUT_DIMS_INVALID); + + *mul = 1; + for (size_t i = 0; i < input_tensor->shape_size_ - input_hw_dims; i++) { + *mul *= input_tensor->shape_[i]; + } + *height = input_tensor->shape_[input_tensor->shape_size_ - Num2]; + *width = input_tensor->shape_[input_tensor->shape_size_ - Num1]; + + return NNACL_OK; +} + +int TriuTrilGetKValue(KernelBase *self, int64_t *k) { + if (self->in_size_ <= 1) { + *k = 0; + return NNACL_OK; + } + + TensorC *k_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(k_tensor); + NNACL_CHECK_NULL_RETURN_ERR(k_tensor->data_); + + switch (k_tensor->data_type_) { + case kNumberTypeInt: + case kNumberTypeInt32: + *k = *((int32_t *)k_tensor->data_); + break; + case kNumberTypeInt64: + *k = *((int64_t *)k_tensor->data_); + break; + default: + return NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int64_t *src_data = (const int64_t *)src; + int64_t *dst_data = (int64_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? 
src_data[index] : 0; + } + } + } +} + +void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int32_t *src_data = (const int32_t *)src; + int32_t *dst_data = (int32_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} + +void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int16_t *src_data = (const int16_t *)src; + int16_t *dst_data = (int16_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} +void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int8_t *src_data = (const int8_t *)src; + int8_t *dst_data = (int8_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} + +void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int64_t *src_data = (const int64_t *)src; + int64_t *dst_data = (int64_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} + +void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int32_t *src_data = (const int32_t *)src; + int32_t *dst_data = (int32_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} +void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int16_t *src_data = (const int16_t *)src; + int16_t *dst_data = (int16_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} +void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int8_t *src_data = (const int8_t *)src; + int8_t *dst_data = (int8_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? 
src_data[index] : 0; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h new file mode 100644 index 00000000..24205877 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/triu_tril_fp32.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ +#define MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width); +int TriuTrilGetKValue(KernelBase *self, int64_t *k); + +void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); + +void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.c new file mode 100644 index 00000000..d0716dd1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
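+ *
+ * Worked example (added for illustration): Unique on input {1.0f, 2.0f, 1.0f}
+ * produces output0 = {1.0f, 2.0f}, *output0_len = 2 and output1 = {0, 1, 0};
+ * output1[i] is the index of input[i] within the deduplicated output0. The
+ * linear Find() scan per element makes the whole pass O(n^2).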
+ */ + +#include "nnacl_c/fp32/unique_fp32.h" + +int Find(const float *array, int len, float target) { + if (array == NULL) { + return -1; + } + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void Unique(const float *input, int input_len, float *output0, int32_t *output0_len, int32_t *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = Find(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} + +int FindInt(const int32_t *array, int len, int target) { + if (array == NULL) { + return -1; + } + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void UniqueInt(const int32_t *input, int input_len, int32_t *output0, int32_t *output0_len, int32_t *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = FindInt(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.h new file mode 100644 index 00000000..3b503922 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/unique_fp32.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_UNIQUE_H +#define MINDSPORE_NNACL_UNIQUE_H + +#include "nnacl_c/op_base.h" + +typedef struct UniqueParameter { + // primitive parameter + OpParameter op_parameter_; +} UniqueParameter; + +#ifdef __cplusplus +extern "C" { +#endif +void Unique(const float *input, int input_len, float *output0, int32_t *output0_len, int32_t *output1); +void UniqueInt(const int32_t *input, int input_len, int32_t *output0, int32_t *output0_len, int32_t *output1); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_UNIQUE_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.c new file mode 100644 index 00000000..efcc3530 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.c @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
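+ *
+ * Broadcast note (worked example added for illustration): condition_, x and
+ * y are each indexed per element when their *_num_ is greater than 1 and
+ * read as a broadcast scalar otherwise. E.g. condition = {true, false, true},
+ * x = {10, 20, 30}, y = {0} with y_num_ = 1 yields output = {10, 0, 30}.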
+ */ + +#include "nnacl_c/fp32/where_fp32.h" +#include "nnacl_c/common_func.h" + +void WhereWithTripleInputs(const float *x, const float *y, float *output, const WhereArgs *param, int task_id, + int thread_num) { + const bool *condition = param->condition_; + int stride = UP_DIV(param->max_num_, thread_num); + int begin = task_id * stride; + int end = MSMIN(begin + stride, param->max_num_); + + for (int i = begin; i < end; ++i) { + bool cond = condition[param->condition_num_ > 1 ? i : 0]; + if (cond) { + output[i] = x[param->x_num_ > 1 ? i : 0]; + } else { + output[i] = y[param->y_num_ > 1 ? i : 0]; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.h new file mode 100644 index 00000000..d1112c0a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/where_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_WHERE_FP32_H_ +#define MINDSPORE_NNACL_FP32_WHERE_FP32_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/where_parameter.h" +#include "nnacl_c/kernel/where.h" + +#ifdef __cplusplus +extern "C" { +#endif +void WhereWithTripleInputs(const float *x, const float *y, float *output, const WhereArgs *param, int task_id, + int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_WHERE_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.c new file mode 100644 index 00000000..e3be43d9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.c @@ -0,0 +1,2233 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless re256uired by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
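+ *
+ * Note (inferred from the code below, stated as an assumption): the 4x4
+ * input transform here matches the standard Winograd F(2x2, 3x3) input
+ * transform BT * d * B with
+ *   BT = | 1   0  -1   0 |
+ *        | 0   1   1   0 |
+ *        | 0  -1   1   0 |
+ *        | 0  -1   0   1 |
+ * applied once across rows (the t[] loop) and once across columns (the m[]
+ * loop), on 8 channels at a time per AVX register.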
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/fp32/winograd_avx.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[16]; + LoadAvx16Data; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_SUB256_F32(src[offset], src[2 + offset]); + t[4 + l] = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[8 + l] = MS_SUB256_F32(src[2 + offset], src[1 + offset]); + t[12 + l] = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = MS_SUB256_F32(t[offset], t[2 + offset]); + m[4 + l] = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[8 + l] = MS_SUB256_F32(t[2 + offset], t[1 + offset]); + m[12 + l] = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[36]; + MS_FLOAT32X8 m[36]; + LoadAvx36Data; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(src[4 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 4), MS_MUL256_N_F32(src[2 + offset], 5)), + src[4 + offset]); + t[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(src[1 + offset], src[2 + offset]), -4), + MS_ADD256_F32(src[3 + offset], src[4 + offset])); + t[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), 4), + MS_SUB256_F32(src[4 + offset], src[3 + offset])); + t[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + t[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[1 + offset], 4), MS_MUL256_N_F32(src[3 + offset], 5)), + src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(t[4 + offset], t[2 + offset]); + m[l] = + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 4), MS_MUL256_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(t[1 + offset], t[2 + offset]), -4), + MS_ADD256_F32(t[3 + offset], t[4 + offset])); + m[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), 4), + MS_SUB256_F32(t[4 + offset], t[3 + offset])); + m[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + m[24 + l] = 
MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + m[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[1 + offset], 4), MS_MUL256_N_F32(t[3 + offset], 5)), + t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform8x8AvxUnit_block8(const float *src_data, float *dst_data, const int src_step, const int dst_step) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[64]; + MS_FLOAT32X8 m[64]; + LoadAvx64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 0.5625), MS_MUL256_N_F32(src[2 + offset], 3.0625)), + MS_MUL256_N_F32(src[4 + offset], 3.5)), + src[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 1.125), MS_MUL256_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 2.25), MS_MUL256_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.5625), MS_MUL256_N_F32(src[4 + offset], 2.5)); + t[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.375), MS_MUL256_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.25), MS_MUL256_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], -0.5625), 
MS_MUL256_N_F32(src[3 + offset], 3.0625)), + MS_MUL256_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 0.5625), MS_MUL256_N_F32(t[2 + offset], 3.0625)), + MS_MUL256_N_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 1.125), MS_MUL256_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 2.25), MS_MUL256_N_F32(t[4 + offset], 3.25)); + m[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.5625), MS_MUL256_N_F32(t[4 + offset], 2.5)); + m[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.375), MS_MUL256_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.25), MS_MUL256_N_F32(t[4 + offset], 1.25)); + m[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], -0.5625), MS_MUL256_N_F32(t[3 + offset], 3.0625)), + MS_MUL256_N_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } +} + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + if (real_c == C8NUM) { + InputTransform8x8AvxUnit_block8(src_data, dst_data, src_step, dst_step); + } else { + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * 
t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 
+void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[16];
+  MS_FLOAT32X8 t[8];
+  MS_FLOAT32X8 m[4];
+  LoadAvx16Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]);
+    t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 4;
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr);
+    m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[16];
+  MS_FLOAT32X8 t[8];
+  MS_FLOAT32X8 m[4];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx16Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]);
+    t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 4;
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr);
+    m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 2] = MS_MAX256_F32(zero, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform4x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[16];
+  MS_FLOAT32X8 t[8];
+  MS_FLOAT32X8 m[4];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx16Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]);
+    t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 4;
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr);
+    m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 2] = MS_MAX256_F32(zero, m[l + 2]);
+    m[l + 2] = MS_MIN256_F32(six, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform4x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[16];
+  MS_FLOAT32X8 t[12];
+  MS_FLOAT32X8 m[9];
+  LoadAvx16Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    t[l] = MS_ADD256_F32(src[offset], tmp);
+    t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 4;
+    MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr);
+    m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform4x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[16];
+  MS_FLOAT32X8 t[12];
+  MS_FLOAT32X8 m[9];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx16Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    t[l] = MS_ADD256_F32(src[offset], tmp);
+    t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 4;
+    MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr);
+    m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 3] = MS_MAX256_F32(zero, m[l + 3]);
+    m[l + 6] = MS_MAX256_F32(zero, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform4x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[16];
+  MS_FLOAT32X8 t[12];
+  MS_FLOAT32X8 m[9];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx16Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 4;
+    MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    t[l] = MS_ADD256_F32(src[offset], tmp);
+    t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 4;
+    MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr);
+    m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 3] = MS_MAX256_F32(zero, m[l + 3]);
+    m[l + 3] = MS_MIN256_F32(six, m[l + 3]);
+    m[l + 6] = MS_MAX256_F32(zero, m[l + 6]);
+    m[l + 6] = MS_MIN256_F32(six, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
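+// The 6x{2..5} variants below consume a 6x6 input tile (LoadAvx36Data). The small
+// integer factors (2, 4, 8, 16) fed to MS_MUL256_N_F32 are consistent with
+// evaluating the transform rows at the sample points 0, +-1, +-2: row k of A^T
+// uses the k-th powers of those points.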
+void OutputTransform6x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[12];
+  MS_FLOAT32X8 m[4];
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    t[l] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]),
+      src[4 + offset]);
+    t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]),
+                                           MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)),
+                             src[5 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 6;
+    m[l] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]),
+                    t[4 + offset]),
+      bias_ptr);
+    m[l + 2] =
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]),
+                                                MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)),
+                                  t[5 + offset]),
+                    bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[12];
+  MS_FLOAT32X8 m[4];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    t[l] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]),
+      src[4 + offset]);
+    t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]),
+                                           MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)),
+                             src[5 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 6;
+    m[l] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]),
+                    t[4 + offset]),
+      bias_ptr);
+    m[l + 2] =
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]),
+                                                MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)),
+                                  t[5 + offset]),
+                    bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 2] = MS_MAX256_F32(zero, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[12];
+  MS_FLOAT32X8 m[4];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    t[l] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]),
+      src[4 + offset]);
+    t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]),
+                                           MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)),
+                             src[5 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 6;
+    m[l] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]),
+                    t[4 + offset]),
+      bias_ptr);
+    m[l + 2] =
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]),
+                                                MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)),
+                                  t[5 + offset]),
+                    bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 2] = MS_MAX256_F32(zero, m[l + 2]);
+    m[l + 2] = MS_MIN256_F32(six, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[18];
+  MS_FLOAT32X8 m[9];
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]),
+                             MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]),
+                                           MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)),
+                             bias_ptr);
+    m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[18];
+  MS_FLOAT32X8 m[9];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]),
+                             MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]),
+                                           MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)),
+                             bias_ptr);
+    m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 3] = MS_MAX256_F32(zero, m[l + 3]);
+    m[l + 6] = MS_MAX256_F32(zero, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[18];
+  MS_FLOAT32X8 m[9];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]),
+                             MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2));
+    t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]),
+                                           MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)),
+                             bias_ptr);
+    m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 3] = MS_MAX256_F32(zero, m[l + 3]);
+    m[l + 3] = MS_MIN256_F32(six, m[l + 3]);
+    m[l + 6] = MS_MAX256_F32(zero, m[l + 6]);
+    m[l + 6] = MS_MIN256_F32(six, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[24];
+  MS_FLOAT32X8 m[16];
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2));
+    t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4));
+    t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr);
+    m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr);
+    m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    StoreAvx16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[24];
+  MS_FLOAT32X8 m[16];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2));
+    t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4));
+    t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr);
+    m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr);
+    m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 4] = MS_MAX256_F32(zero, m[l + 4]);
+    m[l + 8] = MS_MAX256_F32(zero, m[l + 8]);
+    m[l + 12] = MS_MAX256_F32(zero, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    StoreAvx16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[24];
+  MS_FLOAT32X8 m[16];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2));
+    t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4));
+    t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr);
+    m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr);
+    m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 4] = MS_MAX256_F32(zero, m[l + 4]);
+    m[l + 4] = MS_MIN256_F32(six, m[l + 4]);
+    m[l + 8] = MS_MAX256_F32(zero, m[l + 8]);
+    m[l + 8] = MS_MIN256_F32(six, m[l + 8]);
+    m[l + 12] = MS_MAX256_F32(zero, m[l + 12]);
+    m[l + 12] = MS_MIN256_F32(six, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    StoreAvx16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[30];
+  MS_FLOAT32X8 m[25];
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2));
+    t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4));
+    t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8));
+    t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr);
+    m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr);
+    m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr);
+    m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    StoreAvx25Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[30];
+  MS_FLOAT32X8 m[25];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2));
+    t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4));
+    t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8));
+    t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr);
+    m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr);
+    m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr);
+    m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 5] = MS_MAX256_F32(zero, m[l + 5]);
+    m[l + 10] = MS_MAX256_F32(zero, m[l + 10]);
+    m[l + 15] = MS_MAX256_F32(zero, m[l + 15]);
+    m[l + 20] = MS_MAX256_F32(zero, m[l + 20]);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    StoreAvx25Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform6x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[36];
+  MS_FLOAT32X8 t[30];
+  MS_FLOAT32X8 m[25];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx36Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 6; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2);
+    t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2));
+    t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4));
+    t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8));
+    t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 6;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr);
+    m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr);
+    m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr);
+    m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr);
+    m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 5] = MS_MAX256_F32(zero, m[l + 5]);
+    m[l + 5] = MS_MIN256_F32(six, m[l + 5]);
+    m[l + 10] = MS_MAX256_F32(zero, m[l + 10]);
+    m[l + 10] = MS_MIN256_F32(six, m[l + 10]);
+    m[l + 15] = MS_MAX256_F32(zero, m[l + 15]);
+    m[l + 15] = MS_MIN256_F32(six, m[l + 15]);
+    m[l + 20] = MS_MAX256_F32(zero, m[l + 20]);
+    m[l + 20] = MS_MIN256_F32(six, m[l + 20]);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    StoreAvx25Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
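+// The 8x{2..6} variants consume an 8x8 tile (LoadAvx64Data). The coefficient ladder
+// 0.5/1.5, 0.25/2.25, 0.125/3.375, 0.0625/5.0625, 0.03125/7.59375 is the successive
+// powers of 0.5 and 1.5, which suggests the tile was built over the sample points
+// 0, +-1, +-0.5, +-1.5, with the tail point carried in src[7 + offset].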
+void OutputTransform8x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[16];
+  MS_FLOAT32X8 m[4];
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             src[7 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 2] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                    t[7 + offset]),
+      bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[16];
+  MS_FLOAT32X8 m[4];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             src[7 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 2] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                    t[7 + offset]),
+      bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 2] = MS_MAX256_F32(zero, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[16];
+  MS_FLOAT32X8 m[4];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             src[7 + offset]);
+  }
+  for (int l = 0; l < 2; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 2] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                    t[7 + offset]),
+      bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 2] = MS_MAX256_F32(zero, m[l + 2]);
+    m[l + 2] = MS_MIN256_F32(six, m[l + 2]);
+  }
+  if (r_c == C8NUM && r_h == 2 && r_w == 2) {
+    StoreAvx4Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 2;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[24];
+  MS_FLOAT32X8 m[9];
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 6] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)),
+                    t[7 + offset]),
+      bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[24];
+  MS_FLOAT32X8 m[9];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 6] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)),
+                    t[7 + offset]),
+      bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 3] = MS_MAX256_F32(zero, m[l + 3]);
+    m[l + 6] = MS_MAX256_F32(zero, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[24];
+  MS_FLOAT32X8 m[9];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]);
+  }
+  for (int l = 0; l < 3; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 6] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)),
+                    t[7 + offset]),
+      bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 3] = MS_MAX256_F32(zero, m[l + 3]);
+    m[l + 3] = MS_MIN256_F32(six, m[l + 3]);
+    m[l + 6] = MS_MAX256_F32(zero, m[l + 6]);
+    m[l + 6] = MS_MIN256_F32(six, m[l + 6]);
+  }
+  if (r_c == C8NUM && r_h == 3 && r_w == 3) {
+    StoreAvx9Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 3;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[32];
+  MS_FLOAT32X8 m[16];
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25));
+    t[l + 24] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 8] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr);
+    m[l + 12] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)),
+                    t[7 + offset]),
+      bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    StoreAvx16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[32];
+  MS_FLOAT32X8 m[16];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25));
+    t[l + 24] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 8] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr);
+    m[l + 12] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)),
+                    t[7 + offset]),
+      bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 4] = MS_MAX256_F32(zero, m[l + 4]);
+    m[l + 8] = MS_MAX256_F32(zero, m[l + 8]);
+    m[l + 12] = MS_MAX256_F32(zero, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    StoreAvx16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                    int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[32];
+  MS_FLOAT32X8 m[16];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  MS_FLOAT32X8 six = MS_MOV256_F32(6);
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25));
+    t[l + 24] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]);
+  }
+  for (int l = 0; l < 4; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 8] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr);
+    m[l + 12] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)),
+                    t[7 + offset]),
+      bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l] = MS_MIN256_F32(six, m[l]);
+    m[l + 4] = MS_MAX256_F32(zero, m[l + 4]);
+    m[l + 4] = MS_MIN256_F32(six, m[l + 4]);
+    m[l + 8] = MS_MAX256_F32(zero, m[l + 8]);
+    m[l + 8] = MS_MIN256_F32(six, m[l + 8]);
+    m[l + 12] = MS_MAX256_F32(zero, m[l + 12]);
+    m[l + 12] = MS_MIN256_F32(six, m[l + 12]);
+  }
+  if (r_c == C8NUM && r_h == 4 && r_w == 4) {
+    StoreAvx16Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 4;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                               int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[40];
+  MS_FLOAT32X8 m[25];
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25));
+    t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375));
+    t[l + 32] =
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)),
+                    src[7 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 10] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr);
+    m[l + 15] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr);
+    m[l + 20] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)),
+                    t[7 + offset]),
+      bias_ptr);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    StoreAvx25Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
+void OutputTransform8x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                   int dst_step, int out_c, int r_w, int r_h, int r_c) {
+  MS_FLOAT32X8 src[64];
+  MS_FLOAT32X8 t[40];
+  MS_FLOAT32X8 m[25];
+  MS_FLOAT32X8 zero = MS_MOV256_F32(0);
+  LoadAvx64Data;
+  MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data);
+  for (int l = 0; l < 8; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]);
+    t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3);
+    t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5));
+    t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25));
+    t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375));
+    t[l + 32] =
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)),
+                    src[7 + offset]);
+  }
+  for (int l = 0; l < 5; ++l) {
+    int offset = l * 8;
+    MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]);
+    MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]);
+    MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]);
+    MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]);
+    m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr);
+    m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)),
+                             bias_ptr);
+    m[l + 10] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr);
+    m[l + 15] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr);
+    m[l + 20] = MS_ADD256_F32(
+      MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)),
+                    t[7 + offset]),
+      bias_ptr);
+    m[l] = MS_MAX256_F32(zero, m[l]);
+    m[l + 5] = MS_MAX256_F32(zero, m[l + 5]);
+    m[l + 10] = MS_MAX256_F32(zero, m[l + 10]);
+    m[l + 15] = MS_MAX256_F32(zero, m[l + 15]);
+    m[l + 20] = MS_MAX256_F32(zero, m[l + 20]);
+  }
+  if (r_c == C8NUM && r_h == 5 && r_w == 5) {
+    StoreAvx25Data;
+  } else {
+    for (int i = 0; i < r_c; i++) {
+      for (int j = 0; j < r_h; j++) {
+        int dst_k_offset = j * dst_step * out_c;
+        int m_k_offset = j * 5;
+        for (int k = 0; k < r_w; k++) {
+          dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i);
+        }
+      }
+    }
+  }
+}
+
MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 5] = MS_MIN256_F32(six, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 10] = MS_MIN256_F32(six, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 15] = MS_MIN256_F32(six, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + m[l + 20] = MS_MIN256_F32(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), 
bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( 
+ MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 18] = MS_MAX256_F32(zero, m[l + 18]); + m[l + 24] = MS_MAX256_F32(zero, m[l + 24]); + m[l + 30] = MS_MAX256_F32(zero, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + 
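// Second 1-D pass of the 2-D output transform: fold the eight
+    // row-transformed vectors t[] into six outputs per column, add the
+    // bias, then clamp to [0, 6] via the MS_MAX/MS_MIN pair below
+    // (the fused ReLU6).
+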
m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 12] = MS_MIN256_F32(six, m[l + 12]); + m[l + 18] = MS_MAX256_F32(zero, m[l + 18]); + m[l + 18] = MS_MIN256_F32(six, m[l + 18]); + m[l + 24] = MS_MAX256_F32(zero, m[l + 24]); + m[l + 24] = MS_MIN256_F32(six, m[l + 24]); + m[l + 30] = MS_MAX256_F32(zero, m[l + 30]); + m[l + 30] = MS_MIN256_F32(six, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = 
MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = 
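/* odd-order rows (t[l + 8], t[l + 24], t[l + 40]) combine the
        difference terms tmp4..tmp6; even-order rows use the sums
        tmp1..tmp3 */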
MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 7] = MS_MAX256_F32(zero, m[l + 7]); + m[l + 14] = MS_MAX256_F32(zero, m[l + 14]); + m[l + 21] = MS_MAX256_F32(zero, m[l + 21]); + m[l + 28] = MS_MAX256_F32(zero, m[l + 28]); + m[l + 35] = MS_MAX256_F32(zero, m[l + 35]); + m[l + 42] = MS_MAX256_F32(zero, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + 
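// F(7x7, 2x2)-style output stage for the 8x8 tile with fused ReLU6:
+  // conceptually dst = clamp(At * M * A + bias, 0, 6), evaluated at the
+  // interpolation points 0, +-0.5, +-1, +-1.5 plus one pass-through term,
+  // hence the scalar factors are powers of 0.5 and 1.5
+  // (0.015625 = 0.5^6, 11.390625 = 1.5^6).
+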
MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 7] = MS_MAX256_F32(zero, m[l + 7]); + m[l + 7] = MS_MIN256_F32(six, m[l + 7]); + m[l + 14] = MS_MAX256_F32(zero, m[l + 14]); + m[l + 14] = MS_MIN256_F32(six, m[l + 14]); + m[l + 21] = MS_MAX256_F32(zero, m[l + 21]); + m[l + 21] = MS_MIN256_F32(six, m[l + 21]); + m[l + 28] = MS_MAX256_F32(zero, m[l + 28]); + m[l + 28] = MS_MIN256_F32(six, m[l + 28]); + m[l + 35] = MS_MAX256_F32(zero, m[l + 35]); + m[l + 35] = MS_MIN256_F32(six, m[l + 35]); + m[l + 42] = MS_MAX256_F32(zero, m[l + 42]); + m[l + 42] = MS_MIN256_F32(six, m[l + 42]); + } + if (r_c == C8NUM && 
r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.h new file mode 100644 index 00000000..a9843129 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_avx.h @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_AVX +#ifndef MINDSPORE_NNACL_WINOGRAD_AVX_H_ +#define MINDSPORE_NNACL_WINOGRAD_AVX_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +#define LoadAvx16Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); + +#define LoadAvx36Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); \ + src[16] = MS_LD256_F32(src_data + 16 * src_step); \ + src[17] = MS_LD256_F32(src_data + 17 * src_step); \ + src[18] = MS_LD256_F32(src_data + 18 * src_step); \ + src[19] = MS_LD256_F32(src_data + 19 * src_step); \ + src[20] = MS_LD256_F32(src_data + 20 * src_step); \ + src[21] = MS_LD256_F32(src_data + 21 * src_step); \ + src[22] = MS_LD256_F32(src_data + 22 * src_step); \ + src[23] = MS_LD256_F32(src_data + 23 * src_step); \ + src[24] = MS_LD256_F32(src_data + 24 * src_step); \ + src[25] = MS_LD256_F32(src_data + 25 * src_step); \ + src[26] = MS_LD256_F32(src_data + 26 * src_step); \ + src[27] = MS_LD256_F32(src_data + 27 * src_step); \ + src[28] = MS_LD256_F32(src_data + 28 * src_step); \ + src[29] = MS_LD256_F32(src_data + 29 * src_step); \ + src[30] = MS_LD256_F32(src_data + 30 * src_step); \ + src[31] = MS_LD256_F32(src_data + 31 * src_step); \ + src[32] = MS_LD256_F32(src_data + 32 * src_step); \ + src[33] = MS_LD256_F32(src_data + 33 * src_step); \ + src[34] = MS_LD256_F32(src_data + 34 * src_step); \ + src[35] = MS_LD256_F32(src_data + 35 * src_step); + +#define LoadAvx64Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = 
MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); \ + src[16] = MS_LD256_F32(src_data + 16 * src_step); \ + src[17] = MS_LD256_F32(src_data + 17 * src_step); \ + src[18] = MS_LD256_F32(src_data + 18 * src_step); \ + src[19] = MS_LD256_F32(src_data + 19 * src_step); \ + src[20] = MS_LD256_F32(src_data + 20 * src_step); \ + src[21] = MS_LD256_F32(src_data + 21 * src_step); \ + src[22] = MS_LD256_F32(src_data + 22 * src_step); \ + src[23] = MS_LD256_F32(src_data + 23 * src_step); \ + src[24] = MS_LD256_F32(src_data + 24 * src_step); \ + src[25] = MS_LD256_F32(src_data + 25 * src_step); \ + src[26] = MS_LD256_F32(src_data + 26 * src_step); \ + src[27] = MS_LD256_F32(src_data + 27 * src_step); \ + src[28] = MS_LD256_F32(src_data + 28 * src_step); \ + src[29] = MS_LD256_F32(src_data + 29 * src_step); \ + src[30] = MS_LD256_F32(src_data + 30 * src_step); \ + src[31] = MS_LD256_F32(src_data + 31 * src_step); \ + src[32] = MS_LD256_F32(src_data + 32 * src_step); \ + src[33] = MS_LD256_F32(src_data + 33 * src_step); \ + src[34] = MS_LD256_F32(src_data + 34 * src_step); \ + src[35] = MS_LD256_F32(src_data + 35 * src_step); \ + src[36] = MS_LD256_F32(src_data + 36 * src_step); \ + src[37] = MS_LD256_F32(src_data + 37 * src_step); \ + src[38] = MS_LD256_F32(src_data + 38 * src_step); \ + src[39] = MS_LD256_F32(src_data + 39 * src_step); \ + src[40] = MS_LD256_F32(src_data + 40 * src_step); \ + src[41] = MS_LD256_F32(src_data + 41 * src_step); \ + src[42] = MS_LD256_F32(src_data + 42 * src_step); \ + src[43] = MS_LD256_F32(src_data + 43 * src_step); \ + src[44] = MS_LD256_F32(src_data + 44 * src_step); \ + src[45] = MS_LD256_F32(src_data + 45 * src_step); \ + src[46] = MS_LD256_F32(src_data + 46 * src_step); \ + src[47] = MS_LD256_F32(src_data + 47 * src_step); \ + src[48] = MS_LD256_F32(src_data + 48 * src_step); \ + src[49] = MS_LD256_F32(src_data + 49 * src_step); \ + src[50] = MS_LD256_F32(src_data + 50 * src_step); \ + src[51] = MS_LD256_F32(src_data + 51 * src_step); \ + src[52] = MS_LD256_F32(src_data + 52 * src_step); \ + src[53] = MS_LD256_F32(src_data + 53 * src_step); \ + src[54] = MS_LD256_F32(src_data + 54 * src_step); \ + src[55] = MS_LD256_F32(src_data + 55 * src_step); \ + src[56] = MS_LD256_F32(src_data + 56 * src_step); \ + src[57] = MS_LD256_F32(src_data + 57 * src_step); \ + src[58] = MS_LD256_F32(src_data + 58 * src_step); \ + src[59] = MS_LD256_F32(src_data + 59 * src_step); \ + src[60] = MS_LD256_F32(src_data + 60 * src_step); \ + src[61] = MS_LD256_F32(src_data + 61 * src_step); \ + src[62] = MS_LD256_F32(src_data + 62 * src_step); \ + src[63] = MS_LD256_F32(src_data + 63 * src_step); + +#define StoreAvx4Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[2]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[3]); + +#define StoreAvx9Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[3]); \ + MS_ST256_F32(dst_data + 
dst_step * out_c + out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[6]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define StoreAvx16Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + 3 * out_c, m[3]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[5]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[8]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c, m[12]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define StoreAvx25Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + 3 * out_c, m[3]); \ + MS_ST256_F32(dst_data + 4 * out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[5]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[6]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[10]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c, m[15]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c, m[20]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6AvxUnit(const float *src_data, float 
*dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6AvxUnit(const float *src_data, float *dst_data, const float 
*bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_AVX_H_ +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.c new file mode 100644 index 00000000..49960ef4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.c @@ -0,0 +1,281 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" + +void PrepareTransInput(const float *src_data, float *dst_data, int interval_x_s, int interval_x_e, int interval_y_s, + int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; +#ifdef ENABLE_AVX + int channel_tile = C8NUM; +#else + int channel_tile = C4NUM; +#endif + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * channel_tile * (int)(sizeof(float))); + } + + // get real input block with padding + if (real_c == channel_tile) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_AVX + MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr)); +#else + for (int k = 0; k < channel_tile; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } // interval x loop + } // interval y loop + } else { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } // interval x loop + } // interval y loop + } +} + +// fp32 conv winograd +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFunc func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; +#ifdef ENABLE_AVX + int channel_tile = C8NUM; +#else + int channel_tile = C4NUM; +#endif + int ic4 = UP_DIV(in_channel, channel_tile); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + NNACL_CHECK_ZERO_RETURN(out_w_block_num); + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? 
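/* sampling window clipped to the padded input; rows/cols outside
      [interval_s, interval_e) stay zero (see the memset in
      PrepareTransInput) */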
input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * in_channel; + for (int ic = 0; ic < ic4; ic++) { + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ? channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); + + // input transform + const int tile_num = C12NUM; + int dst_ic4_offset = dst_plane_offset + ic * channel_tile; + int dst_step = tile_num * in_channel; + float *trans_input_ptr = trans_input + dst_ic4_offset; + func(tmp_data, trans_input_ptr, channel_tile, dst_step, real_c); + } + out_tile_index++; + } // cal_tile_num loop +} + +// Only support arm64 +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int channel_tile = C4NUM; + int ic4 = UP_DIV(in_channel, channel_tile); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + NNACL_CHECK_ZERO_RETURN(out_w_block_num); + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * channel_tile; + for (int ic = 0; ic < ic4; ic++) { + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ? 
channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); + + // input transform + const int block_tile = C12NUM; + int dst_ic8_offset = dst_plane_offset + ic * block_tile * input_unit * input_unit * channel_tile; + size_t dst_step = (size_t)(input_unit * block_tile * channel_tile); + float *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, channel_tile, dst_step, block_tile * channel_tile); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_channel = conv_param->output_channel_; +#ifndef ENABLE_AVX + int oc4 = UP_DIV(output_channel, C4NUM); +#endif + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w); +#ifndef ENABLE_AVX + // avx is write to nc4hw4 + for (int j = 0; j < oc4; j++) { + int c8_block = j / 2; + int c8_res = j % 2; + int r_c = output_channel - j * C4NUM; + r_c = r_c > C4NUM ? C4NUM : r_c; + int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } +#else + // avx is write to nc8hw8 + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? 
C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM; + const float *src_ptr = gemm_out + src_oc8_offset; + const float *bias_ptr = bias_data + j * C8NUM; + float *dst_ptr = out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } +#endif + out_tile_index++; + } +} + +void WinogradOutputNC4HW4Transform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_plane = output_w * output_h; + int output_channel = conv_param->output_channel_; +#ifndef ENABLE_AVX + int oc4 = UP_DIV(output_channel, C4NUM); +#endif + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = dst_x_s + dst_y_s * output_w; +#ifdef ENABLE_AVX + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = (dst_tile_offset + output_plane * j) * C8NUM; + const float *src_ptr = gemm_out + src_oc8_offset; + const float *bias_ptr = bias_data + j * C8NUM; + float *dst_ptr = out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } +#else + for (int j = 0; j < oc4; j++) { + int c8_block = j / 2; + int c8_res = j % 2; + int r_c = output_channel - j * C4NUM; + r_c = r_c > C4NUM ? C4NUM : r_c; + int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM; + int dst_oc4_offset = (dst_tile_offset + output_plane * j) * C4NUM; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } +#endif + out_tile_index++; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.h new file mode 100644 index 00000000..38c19c19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_transform.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ +#define MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include +#include "nnacl_c/pack.h" +#include "nnacl_c/fp32/winograd_utils.h" + +#ifdef __cplusplus +extern "C" { +#endif +// for fp32 winograd input/output transform +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFunc func); + +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func); + +void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func); + +void WinogradOutputNC4HW4Transform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.c new file mode 100644 index 00000000..345cf646 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.c @@ -0,0 +1,4289 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_avx.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/errorcode.h" + +#ifdef ENABLE_ARM64 +void transpose4(MS_FLOAT32X4 *s0, MS_FLOAT32X4 *s1, MS_FLOAT32X4 *s2, MS_FLOAT32X4 *s3) { + float64x2_t m0 = (float64x2_t)(vtrn1q_f32(*s0, *s1)); + float64x2_t m1 = (float64x2_t)(vtrn2q_f32(*s0, *s1)); + float64x2_t m2 = (float64x2_t)(vtrn1q_f32(*s2, *s3)); + float64x2_t m3 = (float64x2_t)(vtrn2q_f32(*s2, *s3)); + *s0 = (float32x4_t)(vtrn1q_f64(m0, m2)); + *s2 = (float32x4_t)(vtrn2q_f64(m0, m2)); + *s1 = (float32x4_t)(vtrn1q_f64(m1, m3)); + *s3 = (float32x4_t)(vtrn2q_f64(m1, m3)); +} +#endif + +#ifdef ENABLE_AVX +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit}; + +static OutputTransFunc OutputTransFuncList[] = { + OutputTransform4x2AvxUnit, OutputTransform4x3AvxUnit, OutputTransform4x2ReluAvxUnit, + OutputTransform4x3ReluAvxUnit, OutputTransform4x2Relu6AvxUnit, OutputTransform4x3Relu6AvxUnit, + OutputTransform6x2AvxUnit, OutputTransform6x3AvxUnit, OutputTransform6x4AvxUnit, + OutputTransform6x5AvxUnit, OutputTransform6x2ReluAvxUnit, OutputTransform6x3ReluAvxUnit, + OutputTransform6x4ReluAvxUnit, OutputTransform6x5ReluAvxUnit, OutputTransform6x2Relu6AvxUnit, + OutputTransform6x3Relu6AvxUnit, OutputTransform6x4Relu6AvxUnit, OutputTransform6x5Relu6AvxUnit, + OutputTransform8x2AvxUnit, OutputTransform8x3AvxUnit, OutputTransform8x4AvxUnit, + OutputTransform8x5AvxUnit, OutputTransform8x6AvxUnit, OutputTransform8x7AvxUnit, + OutputTransform8x2ReluAvxUnit, OutputTransform8x3ReluAvxUnit, OutputTransform8x4ReluAvxUnit, + OutputTransform8x5ReluAvxUnit, OutputTransform8x6ReluAvxUnit, OutputTransform8x7ReluAvxUnit, + OutputTransform8x2Relu6AvxUnit, OutputTransform8x3Relu6AvxUnit, OutputTransform8x4Relu6AvxUnit, + OutputTransform8x5Relu6AvxUnit, OutputTransform8x6Relu6AvxUnit, OutputTransform8x7Relu6AvxUnit}; +#else +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit}; + +static OutputTransFunc OutputTransFuncList[] = { + OutputTransform4x2Unit, OutputTransform4x3Unit, OutputTransform4x2ReluUnit, OutputTransform4x3ReluUnit, + OutputTransform4x2Relu6Unit, OutputTransform4x3Relu6Unit, OutputTransform6x2Unit, OutputTransform6x3Unit, + OutputTransform6x4Unit, OutputTransform6x5Unit, OutputTransform6x2ReluUnit, OutputTransform6x3ReluUnit, + OutputTransform6x4ReluUnit, OutputTransform6x5ReluUnit, OutputTransform6x2Relu6Unit, OutputTransform6x3Relu6Unit, + OutputTransform6x4Relu6Unit, OutputTransform6x5Relu6Unit, OutputTransform8x2Unit, OutputTransform8x3Unit, + OutputTransform8x4Unit, OutputTransform8x5Unit, OutputTransform8x6Unit, OutputTransform8x7Unit, + OutputTransform8x2ReluUnit, OutputTransform8x3ReluUnit, OutputTransform8x4ReluUnit, OutputTransform8x5ReluUnit, + OutputTransform8x6ReluUnit, OutputTransform8x7ReluUnit, OutputTransform8x2Relu6Unit, OutputTransform8x3Relu6Unit, + OutputTransform8x4Relu6Unit, OutputTransform8x5Relu6Unit, OutputTransform8x6Relu6Unit, OutputTransform8x7Relu6Unit}; +#endif + +InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; } + +#ifdef ENABLE_ARM64 +static InputTransStepFunc InputTransStepFuncList[] = { + NULL, NULL, NULL, NULL, 
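+  // The function lists above are indexed directly by input_unit, so only
+  // slots 4, 6 and 8 (the supported Winograd tile sizes) hold a transform;
+  // every other slot is NULL. GetInputTransFunc() performs no bounds or NULL
+  // check itself, so callers are expected to validate the result first,
+  // e.g. (illustrative sketch only, not part of this patch):
+  //   InputTransFunc in_func = GetInputTransFunc(conv_param->input_unit_);
+  //   if (in_func == NULL) {
+  //     return NNACL_ERR;  // unsupported winograd input unit
+  //   }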
InputTransform4x4Step, NULL, InputTransform6x6Step, NULL, InputTransform8x8Step}; + +static InputTransPackFunc InputTransPackFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Pack12, NULL, InputTransform6x6Pack12, NULL, InputTransform8x8Pack12}; + +InputTransStepFunc GetInputTransStepFunc(int input_unit) { return InputTransStepFuncList[input_unit]; } + +InputTransPackFunc GetInputTransPackFunc(int input_unit) { return InputTransPackFuncList[input_unit]; } +#endif + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[16]; + + src[0] = MS_LDQ_F32(src_data); + src[1] = MS_LDQ_F32(src_data + src_step); + src[2] = MS_LDQ_F32(src_data + 2 * src_step); + src[3] = MS_LDQ_F32(src_data + 3 * src_step); + + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + t[l] = MS_SUBQ_F32(src[offset], src[2 + offset]); + src[offset + 4] = MS_LDQ_F32(src_data + (offset + 4) * src_step); + t[4 + l] = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + src[offset + 5] = MS_LDQ_F32(src_data + (offset + 5) * src_step); + t[8 + l] = MS_SUBQ_F32(src[2 + offset], src[1 + offset]); + src[offset + 6] = MS_LDQ_F32(src_data + (offset + 6) * src_step); + t[12 + l] = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + src[offset + 7] = MS_LDQ_F32(src_data + (offset + 7) * src_step); + } + + int offset = 3 * 4; + t[3] = MS_SUBQ_F32(src[offset], src[2 + offset]); + t[7] = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[11] = MS_SUBQ_F32(src[2 + offset], src[1 + offset]); + t[15] = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + + src[0] = MS_SUBQ_F32(t[0], t[2]); + src[1] = MS_ADDQ_F32(t[1], t[2]); + src[2] = MS_SUBQ_F32(t[2], t[1]); + src[3] = MS_SUBQ_F32(t[3], t[1]); + + for (int l = 1; l < 4; ++l) { + offset = l * 4; + src[offset] = MS_SUBQ_F32(t[offset], t[2 + offset]); + MS_STQ_F32(dst_data + (l - 1) * dst_step, src[offset - 4]); + src[offset + 1] = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_STQ_F32(dst_data + (3 + l) * dst_step, src[offset - 3]); + src[offset + 2] = MS_SUBQ_F32(t[2 + offset], t[1 + offset]); + MS_STQ_F32(dst_data + (7 + l) * dst_step, src[offset - 2]); + src[offset + 3] = MS_SUBQ_F32(t[3 + offset], t[1 + offset]); + MS_STQ_F32(dst_data + (11 + l) * dst_step, src[offset - 1]); + } + + MS_STQ_F32(dst_data + 3 * dst_step, src[12]); + MS_STQ_F32(dst_data + dst_step * 7, src[13]); + MS_STQ_F32(dst_data + dst_step * 11, src[14]); + MS_STQ_F32(dst_data + dst_step * 15, src[15]); + + } else { +#endif + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + const 
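+  // The 4x4 input transforms implement the Winograd F(2x2, 3x3) data
+  // transform B^T * d * B as two 1-D passes of the same transform, one per
+  // axis of the 4x4 tile, with
+  //   B^T = | 1  0 -1  0 |
+  //         | 0  1  1  0 |
+  //         | 0 -1  1  0 |
+  //         | 0 -1  0  1 |
+  // which is why every output element costs one add or subtract and no
+  // multiplication.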
float *src_ptr = src_data + l * 4 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s0, s2); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(s1, s2); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s2, s1); + MS_FLOAT32X4 m3 = MS_SUBQ_F32(s3, s1); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + } +#else + float src[4]; + float m[4]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 4; ++l) { + for (int w = 0; w < 4; ++w) { + int tmp_index = l * 4 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 4; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s00, s20); + MS_FLOAT32X4 m1 = MS_SUBQ_F32(s01, s21); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s02, s22); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(s10, s20); + m1 = MS_ADDQ_F32(s11, s21); + m2 = MS_ADDQ_F32(s12, s22); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s20, s10); + m1 = MS_SUBQ_F32(s21, s11); + m2 = MS_SUBQ_F32(s22, s12); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s30, s10); + m1 = MS_SUBQ_F32(s31, s11); + m2 = MS_SUBQ_F32(s32, s12); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 4; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[4]; + float m[4]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 4; ++w) { + src[w] = 
src_data[tmp_index + w * src_point_stride]; + } + + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + for (int w = 0; w < 4; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[36]; + MS_FLOAT32X4 m[36]; + Load36Data; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(src[4 + offset], src[2 + offset]); + t[l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 4), MS_MULQ_N_F32(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(src[1 + offset], src[2 + offset]), -4), + MS_ADDQ_F32(src[3 + offset], src[4 + offset])); + t[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), 4), + MS_SUBQ_F32(src[4 + offset], src[3 + offset])); + t[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + t[30 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[1 + offset], 4), MS_MULQ_N_F32(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(t[3 + offset], t[1 + offset]); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(t[4 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 4), MS_MULQ_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(t[1 + offset], t[2 + offset]), -4), + MS_ADDQ_F32(t[3 + offset], t[4 + offset])); + m[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), 4), + MS_SUBQ_F32(t[4 + offset], t[3 + offset])); + m[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + m[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + m[30 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[1 + offset], 4), MS_MULQ_N_F32(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + MS_STQ_F32(dst_data + i * dst_step, m[i]); + } + } else { +#endif + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = 
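+  // The 6x6 input transforms are the F(4x4, 3x3) counterpart: two 1-D passes
+  // of
+  //   B^T = | 4  0 -5  0  1  0 |
+  //         | 0 -4 -4  1  1  0 |
+  //         | 0  4 -4 -1  1  0 |
+  //         | 0 -2 -1  2  1  0 |
+  //         | 0  2 -1 -2  1  0 |
+  //         | 0  4  0 -5  0  1 |
+  // with tmp1 = d3 - d1 and tmp2 = d4 - d2 factoring out the sub-expressions
+  // shared by rows 4 and 5.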
m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + const float *src_ptr = src_data + l * 6 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(s3, s1); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(s4, s2); + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 4), MS_MULQ_N_F32(s2, 5)), s4); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s1, s2), -4), MS_ADDQ_F32(s3, s4)); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s1, s2), 4), MS_SUBQ_F32(s4, s3)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s1, 4), MS_MULQ_N_F32(s3, 5)), s5); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + } +#else + float src[6]; + float m[6]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 6; ++l) { + for (int w = 0; w < 6; ++w) { + int tmp_index = l * 6 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 6; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 4), MS_MULQ_N_F32(s20, 5)), s40); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 4), MS_MULQ_N_F32(s21, 5)), s41); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 4), MS_MULQ_N_F32(s22, 5)), s42); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s10, s20), -4), MS_ADDQ_F32(s30, s40)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s11, s21), -4), MS_ADDQ_F32(s31, s41)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s12, s22), -4), MS_ADDQ_F32(s32, s42)); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s10, s20), 4), MS_SUBQ_F32(s40, s30)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s11, s21), 4), MS_SUBQ_F32(s41, s31)); + m2 = 
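+  // The Pack12 variants transform 12 tiles per call: TRANSPOSE_12x4 first
+  // reorders each source row from tile-major to channel-major (the
+  // "12 * 4 -> 4 * 12" copy in the scalar fallback does the same job), after
+  // which LOAD_LINE_DATA(n) can fetch row n of all 12 tiles as three
+  // consecutive float32x4 vectors (s<n>0, s<n>1, s<n>2).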
MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s12, s22), 4), MS_SUBQ_F32(s42, s32)); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), 2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), 2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), 2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), -2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), -2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), -2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s10, 4), MS_MULQ_N_F32(s30, 5)), s50); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s11, 4), MS_MULQ_N_F32(s31, 5)), s51); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s12, 4), MS_MULQ_N_F32(s32, 5)), s52); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 6; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[6]; + float m[6]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 6; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + for (int w = 0; w < 6; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[64]; + MS_FLOAT32X4 m[64]; + Load64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = + MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 0.5625), MS_MULQ_N_F32(src[2 + offset], 3.0625)), + MS_MULQ_N_F32(src[4 + offset], 
3.5)), + src[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 1.125), MS_MULQ_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 2.25), MS_MULQ_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.5625), MS_MULQ_N_F32(src[4 + offset], 2.5)); + t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 0.375), MS_MULQ_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.25), MS_MULQ_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], -0.5625), MS_MULQ_N_F32(src[3 + offset], 3.0625)), + MS_MULQ_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 0.5625), MS_MULQ_N_F32(t[2 + offset], 3.0625)), + MS_MULQ_N_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 1.125), MS_MULQ_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 2.25), MS_MULQ_N_F32(t[4 + offset], 3.25)); + m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.5625), MS_MULQ_N_F32(t[4 + offset], 2.5)); + m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.375), MS_MULQ_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.25), MS_MULQ_N_F32(t[4 + offset], 1.25)); + m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], -0.5625), MS_MULQ_N_F32(t[3 + offset], 3.0625)), + MS_MULQ_N_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_STQ_F32(dst_data + i * dst_step, m[i]); + } +} +#endif + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + InputTransform8x8Unit_block4(src_data, 
dst_data, src_step, dst_step); + } else { +#endif + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + const float *src_ptr = src_data + l * 8 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 6 * src_step); + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 7 * src_step); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 0.5625), MS_MULQ_N_F32(s2, 3.0625)), MS_MULQ_N_F32(s4, 3.5)), s6); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 1.125), MS_MULQ_N_F32(s5, 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 2.25), MS_MULQ_N_F32(s4, 3.25)); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 
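+  // The 8x8 (F(6x6, 3x3)) transform uses fractional coefficients, but all of
+  // them are dyadic rationals -- 0.5625 = 9/16, 3.0625 = 49/16, 3.5 = 7/2,
+  // 1.125 = 9/8, 1.625 = 13/8, 1.875 = 15/8 -- so every constant is
+  // representable exactly in binary floating point.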
1.625)), s6); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.625)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.5625), s5); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.5625), MS_MULQ_N_F32(s4, 2.5)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 2.5)), s6); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 2.5)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.375), MS_MULQ_N_F32(s5, 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.25), MS_MULQ_N_F32(s4, 1.25)); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m6 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m7 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s1, -0.5625), MS_MULQ_N_F32(s3, 3.0625)), MS_MULQ_N_F32(s5, 3.5)), s7); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + MS_STQ_F32(dst_ptr + 6 * dst_step, m6); + MS_STQ_F32(dst_ptr + 7 * dst_step, m7); + } +#else + float src[8]; + float m[8]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 8; ++l) { + for (int w = 0; w < 8; ++w) { + int tmp_index = l * 8 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 8; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + LOAD_LINE_DATA(6); + LOAD_LINE_DATA(7); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 0.5625), MS_MULQ_N_F32(s20, 3.0625)), MS_MULQ_N_F32(s40, 3.5)), s60); + MS_FLOAT32X4 m1 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 0.5625), MS_MULQ_N_F32(s21, 3.0625)), MS_MULQ_N_F32(s41, 3.5)), s61); + MS_FLOAT32X4 m2 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 0.5625), MS_MULQ_N_F32(s22, 3.0625)), MS_MULQ_N_F32(s42, 3.5)), s62); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + MS_FLOAT32X4 tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 1.125), MS_MULQ_N_F32(s50, 0.5)); + MS_FLOAT32X4 tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 1.125), MS_MULQ_N_F32(s51, 0.5)); + MS_FLOAT32X4 tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 1.125), MS_MULQ_N_F32(s52, 0.5)); + MS_FLOAT32X4 tmp20 = 
MS_SUBQ_F32(MS_MULQ_N_F32(s20, 2.25), MS_MULQ_N_F32(s40, 3.25)); + MS_FLOAT32X4 tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 2.25), MS_MULQ_N_F32(s41, 3.25)); + MS_FLOAT32X4 tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 2.25), MS_MULQ_N_F32(s42, 3.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.5625), s50); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.5625), s51); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.5625), s52); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.5625), MS_MULQ_N_F32(s40, 2.5)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.5625), MS_MULQ_N_F32(s41, 2.5)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.5625), MS_MULQ_N_F32(s42, 2.5)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.375), MS_MULQ_N_F32(s50, 1.5)); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.375), MS_MULQ_N_F32(s51, 1.5)); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.375), MS_MULQ_N_F32(s52, 1.5)); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.25), MS_MULQ_N_F32(s40, 1.25)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.25), MS_MULQ_N_F32(s41, 1.25)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.25), MS_MULQ_N_F32(s42, 1.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 6 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s10, -0.5625), MS_MULQ_N_F32(s30, 3.0625)), MS_MULQ_N_F32(s50, 3.5)), s70); + m1 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s11, -0.5625), MS_MULQ_N_F32(s31, 3.0625)), MS_MULQ_N_F32(s51, 3.5)), s71); + m2 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s12, -0.5625), MS_MULQ_N_F32(s32, 3.0625)), MS_MULQ_N_F32(s52, 3.5)), s72); + MS_STQ_F32(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 7 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 8; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[8]; + float m[8]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 8; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + for (int w = 0; w < 8; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) { + if (!CheckWinogradInputOutputUnit(input_unit, output_unit)) { + return NULL; + } + int in_index = (input_unit - 4) / 2; + int index = 0; + for (int i = 0; i < in_index; i++) { + index += ((i * 2 + 4) - 2) * 3; + } + int act_index; + if (act_type == ActType_Relu) { + act_index = 1; + } else if (act_type == ActType_Relu6) { + act_index = 2; + } else { + act_index = 0; + } + return OutputTransFuncList[index + (input_unit - 2) * act_index + output_unit - 2]; +} + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || 
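+// OutputTransFuncList is grouped by input unit (4, 6, 8); within a group the
+// order is all plain variants, then all Relu, then all Relu6, each sub-group
+// ordered by output unit. The loop in GetOutputTransFunc skips the groups of
+// smaller input units, each holding (input_unit - 2) * 3 entries. Worked
+// example for input_unit = 6, output_unit = 4, act_type = ActType_Relu:
+//   skip the 4x group:  index = (4 - 2) * 3 = 6
+//   6 + (6 - 2) * 1 + (4 - 2) = 12  ->  OutputTransform6x4Relu(Avx)Unit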
defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + 
offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
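+  // All 4x2 output variants apply the F(2x2, 3x3) output transform
+  // A^T * m * A with
+  //   A^T = | 1  1  1  0 |
+  //         | 0  1 -1  1 |
+  // the Relu/Relu6 flavours differ only in clamping the result (max with 0,
+  // then min with 6) before the store, i.e. the activation is fused into the
+  // tile write-back.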
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * 
out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + 
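+  // One input unit serves several kernel sizes through
+  // input_unit = output_unit + kernel_size - 1: the 4x2 path above is
+  // F(2x2, 3x3) for 3x3 kernels, while this 4x3 path is F(3x3, 2x2) for 2x2
+  // kernels; likewise the 6x4 and 8x6 paths pair with 3x3 kernels.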
bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + 
bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int 
offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + 
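+        /* The store below scatters the r_h x r_w valid corner of the 3x3 result tile
+           into a channel-last (NHWC-style) buffer: for channel i, tile element (k, j)
+           lands at dst_data[i + k * dst_step * out_c + j * out_c], i.e. rows advance
+           by dst_step * out_c floats and columns by out_c floats. For example, with
+           out_c = 16 and dst_step = 8, element (k = 1, j = 2) of channel i = 3 is
+           written at 3 + 1 * 8 * 16 + 2 * 16 = 163. r_w, r_h and r_c clip partial
+           tiles at the right/bottom/channel edges so nothing is written out of bounds. */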
dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 
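+  /* A note on the vector layer used throughout these kernels: the MS_* macros are
+     nnacl's thin wrappers over 4-lane float SIMD; under ENABLE_ARM they are expected
+     to map to NEON (MS_FLOAT32X4 ~ float32x4_t, MS_ADDQ_F32 ~ vaddq_f32,
+     MS_MAXQ_F32 ~ vmaxq_f32, ...) and under ENABLE_SSE to the __m128 intrinsics.
+     MS_MULQ_N_F32(v, s) scales a vector by a scalar and MS_F32X4_GETI(v, i) extracts
+     lane i, which is why the #else branch repeats the same math one float at a time. */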
src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 4] = MS_MINQ_F32(six, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 8] = MS_MINQ_F32(six, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * 
out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + 
offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + 
offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 5] = MS_MINQ_F32(six, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 10] = MS_MINQ_F32(six, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 15] = MS_MINQ_F32(six, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + m[l + 20] = MS_MINQ_F32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - 
src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + 
offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } 
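+    /* Both stages apply the same 1-D output transform A^T: the loop above walks the
+       columns of the 8x8 tile (t = A^T * M) and the loop below walks the rows
+       (m = t * A). Read off the code, the two rows of A^T for this 8-to-2 case are
+         [1, 1,    1,    1, 1,  1,    1,    0]
+         [0, 0.5, -0.5,  1, -1, 1.5, -1.5,  1]
+       consistent with a Winograd F(2, 7) transform built on the interpolation points
+       {0, +-0.5, +-1, +-1.5}, with the extra tap carried by the last column. */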
+ for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + 
t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * 
(src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + 
offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step 
* out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), 
bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + 
src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? 
out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 4] = MS_MINQ_F32(six, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 8] = MS_MINQ_F32(six, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + 
t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + 
offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 5] = MS_MINQ_F32(six, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 10] = MS_MINQ_F32(six, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 15] = MS_MINQ_F32(six, m[l + 15]); + m[l + 20] = 
MS_MAXQ_F32(zero, m[l + 20]); + m[l + 20] = MS_MINQ_F32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? 
out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + 
m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), 
MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 18] = MS_MAXQ_F32(zero, m[l + 18]); + m[l + 24] = MS_MAXQ_F32(zero, m[l + 24]); + m[l + 30] = MS_MAXQ_F32(zero, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + 
offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = 
MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + m[l + 18] = MS_MAXQ_F32(zero, m[l + 18]); + m[l + 18] = MS_MINQ_F32(six, m[l + 18]); + m[l + 24] = MS_MAXQ_F32(zero, m[l + 24]); + m[l + 24] = MS_MINQ_F32(six, m[l + 24]); + m[l + 30] = MS_MAXQ_F32(zero, m[l + 30]); + m[l + 30] = MS_MINQ_F32(six, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + 
offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 
0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); 
+ m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 7] = MS_MAXQ_F32(zero, 
m[l + 7]); + m[l + 14] = MS_MAXQ_F32(zero, m[l + 14]); + m[l + 21] = MS_MAXQ_F32(zero, m[l + 21]); + m[l + 28] = MS_MAXQ_F32(zero, m[l + 28]); + m[l + 35] = MS_MAXQ_F32(zero, m[l + 35]); + m[l + 42] = MS_MAXQ_F32(zero, m[l + 42]); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * 
(t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 7] = MS_MAXQ_F32(zero, m[l + 7]); + m[l + 7] = MS_MINQ_F32(six, m[l + 7]); + m[l + 14] = MS_MAXQ_F32(zero, m[l + 14]); + m[l 
+ 14] = MS_MINQ_F32(six, m[l + 14]); + m[l + 21] = MS_MAXQ_F32(zero, m[l + 21]); + m[l + 21] = MS_MINQ_F32(six, m[l + 21]); + m[l + 28] = MS_MAXQ_F32(zero, m[l + 28]); + m[l + 28] = MS_MINQ_F32(six, m[l + 28]); + m[l + 35] = MS_MAXQ_F32(zero, m[l + 35]); + m[l + 35] = MS_MINQ_F32(six, m[l + 35]); + m[l + 42] = MS_MAXQ_F32(zero, m[l + 42]); + m[l + 42] = MS_MINQ_F32(six, m[l + 42]); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + 
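+      /* Output-transform rows apply powers of the Winograd nodes {0, +-1/2, +-1, +-3/2, inf}: e.g. 0.03125 = 0.5^5, 7.59375 = 1.5^5, 11.390625 = 1.5^6. */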
+ 7.59375f * (t[5 + offset] - t[6 + offset]);
+      m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) +
+                  11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset];
+    }
+    // store output
+    for (int k = 0; k < r_h; ++k) {
+      int dst_k_offset = k * dst_step * out_c;
+      int m_k_offset = k * 7;
+      for (int j = 0; j < r_w; ++j) {
+        float out_value = m[j + m_k_offset] + bias_data[i];
+        out_value = out_value > 0 ? out_value : 0;
+        out_value = out_value < 6 ? out_value : 6;
+        dst_data[i + dst_k_offset + j * out_c] = out_value;
+      }
+    }
+  }
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.h
new file mode 100644
index 00000000..00d4705b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32/winograd_utils.h
@@ -0,0 +1,373 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_WINOGRAD_UTILS_H_
+#define MINDSPORE_NNACL_WINOGRAD_UTILS_H_
+
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
extern "C" {
+#endif
+typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
+
+typedef void (*InputTransStepFunc)(const float *src_data, float *dst_data, int src_step, int dst_step,
+                                   int dst_row_step);
+
+typedef void (*InputTransPackFunc)(float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
+
+typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
+                                int dst_step, int out_c, int r_w, int r_h, int r_c);
+
+typedef struct TransFuncList {
+  InputTransFunc in_func_;
+  InputTransStepFunc in_step_func_;
+  InputTransPackFunc in_pack_func_;
+  OutputTransFunc out_func_;
+} TransFuncList;
+
+#define Load16Data \
+  src[0] = MS_LDQ_F32(src_data + 0 * src_step); \
+  src[1] = MS_LDQ_F32(src_data + 1 * src_step); \
+  src[2] = MS_LDQ_F32(src_data + 2 * src_step); \
+  src[3] = MS_LDQ_F32(src_data + 3 * src_step); \
+  src[4] = MS_LDQ_F32(src_data + 4 * src_step); \
+  src[5] = MS_LDQ_F32(src_data + 5 * src_step); \
+  src[6] = MS_LDQ_F32(src_data + 6 * src_step); \
+  src[7] = MS_LDQ_F32(src_data + 7 * src_step); \
+  src[8] = MS_LDQ_F32(src_data + 8 * src_step); \
+  src[9] = MS_LDQ_F32(src_data + 9 * src_step); \
+  src[10] = MS_LDQ_F32(src_data + 10 * src_step); \
+  src[11] = MS_LDQ_F32(src_data + 11 * src_step); \
+  src[12] = MS_LDQ_F32(src_data + 12 * src_step); \
+  src[13] = MS_LDQ_F32(src_data + 13 * src_step); \
+  src[14] = MS_LDQ_F32(src_data + 14 * src_step); \
+  src[15] = MS_LDQ_F32(src_data + 15 * src_step);
+
+#define Load36Data \
+  src[0] = MS_LDQ_F32(src_data + 0 * src_step); \
+  src[1] = MS_LDQ_F32(src_data + 1 * src_step); \
+  src[2] = MS_LDQ_F32(src_data + 2 * src_step); \
+  src[3] = MS_LDQ_F32(src_data + 3 * src_step); \
+  src[4] = MS_LDQ_F32(src_data + 4 * src_step); \
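+  /* Load36Data gathers the 6x6 input tile (one 4-channel vector per point), typically for the F(4x4, 3x3) unit. */ \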
+ src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ + src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ + src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ + src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ + src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ + src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ + src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ + src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ + src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ + src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ + src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ + src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ + src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ + src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ + src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ + src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ + src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ + src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ + src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ + src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ + src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ + src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ + src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ + src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ + src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ + src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ + src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ + src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ + src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ + src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ + src[35] = MS_LDQ_F32(src_data + 35 * src_step); + +#define Load64Data \ + src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ + src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ + src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ + src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ + src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ + src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ + src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ + src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ + src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ + src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ + src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ + src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ + src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ + src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ + src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ + src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ + src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ + src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ + src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ + src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ + src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ + src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ + src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ + src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ + src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ + src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ + src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ + src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ + src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ + src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ + src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ + src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ + src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ + src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ + src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ + src[35] = MS_LDQ_F32(src_data + 35 * src_step); \ + src[36] = MS_LDQ_F32(src_data + 36 * src_step); \ + src[37] = 
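+  /* Load64Data: the 8x8 tile of the unit-8 transforms; 64 vector loads, 4 channels each. */ \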
MS_LDQ_F32(src_data + 37 * src_step); \ + src[38] = MS_LDQ_F32(src_data + 38 * src_step); \ + src[39] = MS_LDQ_F32(src_data + 39 * src_step); \ + src[40] = MS_LDQ_F32(src_data + 40 * src_step); \ + src[41] = MS_LDQ_F32(src_data + 41 * src_step); \ + src[42] = MS_LDQ_F32(src_data + 42 * src_step); \ + src[43] = MS_LDQ_F32(src_data + 43 * src_step); \ + src[44] = MS_LDQ_F32(src_data + 44 * src_step); \ + src[45] = MS_LDQ_F32(src_data + 45 * src_step); \ + src[46] = MS_LDQ_F32(src_data + 46 * src_step); \ + src[47] = MS_LDQ_F32(src_data + 47 * src_step); \ + src[48] = MS_LDQ_F32(src_data + 48 * src_step); \ + src[49] = MS_LDQ_F32(src_data + 49 * src_step); \ + src[50] = MS_LDQ_F32(src_data + 50 * src_step); \ + src[51] = MS_LDQ_F32(src_data + 51 * src_step); \ + src[52] = MS_LDQ_F32(src_data + 52 * src_step); \ + src[53] = MS_LDQ_F32(src_data + 53 * src_step); \ + src[54] = MS_LDQ_F32(src_data + 54 * src_step); \ + src[55] = MS_LDQ_F32(src_data + 55 * src_step); \ + src[56] = MS_LDQ_F32(src_data + 56 * src_step); \ + src[57] = MS_LDQ_F32(src_data + 57 * src_step); \ + src[58] = MS_LDQ_F32(src_data + 58 * src_step); \ + src[59] = MS_LDQ_F32(src_data + 59 * src_step); \ + src[60] = MS_LDQ_F32(src_data + 60 * src_step); \ + src[61] = MS_LDQ_F32(src_data + 61 * src_step); \ + src[62] = MS_LDQ_F32(src_data + 62 * src_step); \ + src[63] = MS_LDQ_F32(src_data + 63 * src_step); + +#define LOAD_LINE_DATA(line) \ + MS_FLOAT32X4 s##line##0 = MS_LDQ_F32(src_ptr + line * src_point_stride + 0 * pack_tile); \ + MS_FLOAT32X4 s##line##1 = MS_LDQ_F32(src_ptr + line * src_point_stride + 1 * pack_tile); \ + MS_FLOAT32X4 s##line##2 = MS_LDQ_F32(src_ptr + line * src_point_stride + 2 * pack_tile); + +#define TRANSPOSE_12x4 \ + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * pack_tile); \ + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 1 * pack_tile); \ + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 2 * pack_tile); \ + MS_FLOAT32X4 s9 = MS_LDQ_F32(src_ptr + 3 * pack_tile); \ + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 4 * pack_tile); \ + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 5 * pack_tile); \ + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 6 * pack_tile); \ + MS_FLOAT32X4 s10 = MS_LDQ_F32(src_ptr + 7 * pack_tile); \ + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 8 * pack_tile); \ + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 9 * pack_tile); \ + MS_FLOAT32X4 s8 = MS_LDQ_F32(src_ptr + 10 * pack_tile); \ + MS_FLOAT32X4 s11 = MS_LDQ_F32(src_ptr + 11 * pack_tile); \ + transpose4(&s0, &s3, &s6, &s9); \ + transpose4(&s1, &s4, &s7, &s10); \ + transpose4(&s2, &s5, &s8, &s11); \ + MS_STQ_F32(src_ptr + 0 * pack_tile, s0); \ + MS_STQ_F32(src_ptr + 1 * pack_tile, s1); \ + MS_STQ_F32(src_ptr + 2 * pack_tile, s2); \ + MS_STQ_F32(src_ptr + 3 * pack_tile, s3); \ + MS_STQ_F32(src_ptr + 4 * pack_tile, s4); \ + MS_STQ_F32(src_ptr + 5 * pack_tile, s5); \ + MS_STQ_F32(src_ptr + 6 * pack_tile, s6); \ + MS_STQ_F32(src_ptr + 7 * pack_tile, s7); \ + MS_STQ_F32(src_ptr + 8 * pack_tile, s8); \ + MS_STQ_F32(src_ptr + 9 * pack_tile, s9); \ + MS_STQ_F32(src_ptr + 10 * pack_tile, s10); \ + MS_STQ_F32(src_ptr + 11 * pack_tile, s11); + +InputTransFunc GetInputTransFunc(int input_unit); + +#ifdef ENABLE_ARM64 +InputTransStepFunc GetInputTransStepFunc(int input_unit); + +InputTransPackFunc GetInputTransPackFunc(int input_unit); +#endif + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void 
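+/* The Pack12 input-transform variants consume 12 tiles per call, pre-transposed with TRANSPOSE_12x4. */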
InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); + +#define Store4Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[2]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store9Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[3]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[6]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store16Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[5]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[8]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[12]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store25Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ + MS_STQ_F32(dst_data + 4 * out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[5]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[6]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[10]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[15]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + MS_STQ_F32(dst_data + 3 
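+  /* Store25Data writes a full 5x5 output tile, one 4-channel vector per output point. */ \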
* dst_step * out_c + 3 * out_c, m[18]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c, m[20]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int 
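+/* Common contract for the transforms below: r_w / r_h are the valid output columns / rows of this tile, r_c the valid channels (at most 4). */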
r_h, int r_c); +void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +int SelectOutputUnit(const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.c new file mode 100644 index 00000000..c84f5317 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/exp_fp32.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+#include "nnacl_c/activation_grad_simd.h"
+
+int ReluGrad(const float *src0, const float *src1, int length, float *dst) {
+  int i = 0;
+#ifdef ENABLE_ARM
+  float32x4_t zero_4 = vdupq_n_f32(0.0f);
+  for (; i < length - C4NUM; i += C4NUM) {
+    float32x4_t src1_4 = vld1q_f32(src1 + i);
+    float32x4_t src0_4 = vld1q_f32(src0 + i);
+    uint32x4_t mask_4 = vcleq_f32(src1_4, zero_4);
+    float32x4_t dst_4 = vbslq_f32(mask_4, zero_4, src0_4);
+    vst1q_f32(dst + i, dst_4);
+  }
+#endif
+  for (; i < length; ++i) {
+    dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f;
+  }
+  return NNACL_OK;
+}
+
+int Relu6Grad(const float *src0, const float *src1, size_t length, float *dst) {
+  size_t i = 0;
+#ifdef ENABLE_ARM
+  float32x4_t zero_4 = vdupq_n_f32(0.0f);
+  float32x4_t six_4 = vdupq_n_f32(6.0f);
+  // i + C4NUM avoids unsigned underflow of length - C4NUM when length < C4NUM
+  for (; i + C4NUM <= length; i += C4NUM) {
+    float32x4_t src1_4 = vld1q_f32(src1 + i);
+    float32x4_t src0_4 = vld1q_f32(src0 + i);
+    uint32x4_t gt_4 = vcgtq_f32(src1_4, zero_4);
+    uint32x4_t le_4 = vcleq_f32(src1_4, six_4);
+    uint32x4_t mask_4 = vandq_u32(gt_4, le_4);
+    float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4);
+    vst1q_f32(dst + i, dst_4);
+  }
+#endif
+  for (; i < length; ++i) {
+    dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f;
+  }
+  return NNACL_OK;
+}
+
+int LReluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = src1[i] > 0.0f ? src0[i] : alpha * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int SigmoidGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
+  }
+  return NNACL_OK;
+}
+
+int TanhGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int HSwishGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f));
+    dst[i] = tmp * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int HSigmoidGrad(const float *src0, const float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    float tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f));
+    dst[i] = tmp * src0[i];
+  }
+  return NNACL_OK;
+}
+
+int EluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha) {
+  for (size_t i = 0; i < length; ++i) {
+    dst[i] = (src1[i] > 0.0f ?
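+    /* x > 0: dy passes through; x <= 0: dy is scaled by alpha * expm1(x), the slope this kernel uses for the ELU branch */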
src0[i] : alpha * expm1(src1[i]) * src0[i]); + } + return NNACL_OK; +} + +int GeluGrad(const float *src0, const float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + dst[i] = src0[i] * ((0.5 * (1.0 + erf(src1[i] / 1.4142135623730951))) + + (src1[i] * exp(-0.5 * src1[i] * src1[i]) / 2.5066282746)); + } + return NNACL_OK; +} + +int SoftplusGrad(const float *src0, const float *src1, int length, float *dst) { + int i = 0; +#if defined(ENABLE_AVX) + for (; i <= length - C8NUM; i += C8NUM) { + simd_exp256(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src1 + i))), dst + i); + MS_ST256_F32(dst + i, + MS_DIV256_F32(MS_LD256_F32(src0 + i), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i)))); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; i <= length - C4NUM; i += C4NUM) { + simd_exp128(MS_SUBQ_F32(MS_MOVQ_F32(0.0f), MS_LDQ_F32(src1 + i)), dst + i); + MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_LDQ_F32(src0 + i), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i)))); + } +#endif + + for (; i < length; ++i) { + simd_exp32(-src1[i], dst + i); + dst[i] = src0[i] / (1.0f + dst[i]); + } + return NNACL_OK; +} + +int HardShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(ShrinkGrad, i, src0, src1, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src1[i] >= neg_lambd && src1[i] <= lambd) ? 0 : src0[i]; + } + return NNACL_OK; +} + +int SoftShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(ShrinkGrad, i, src0, src1, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src1[i] >= neg_lambd && src1[i] <= lambd) ? 0 : src0[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.h new file mode 100644 index 00000000..ab2933c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_fp32.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_
+#define NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/errorcode.h"
+
+typedef struct ActivationGradParameter {
+  OpParameter op_parameter;
+  int type_;
+  float alpha_;
+} ActivationGradParameter;
+#ifdef __cplusplus
extern "C" {
+#endif
+
+int ReluGrad(const float *src0, const float *src1, int length, float *dst);
+int Relu6Grad(const float *src0, const float *src1, size_t length, float *dst);
+int LReluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha);
+int SigmoidGrad(const float *src0, const float *src1, size_t length, float *dst);
+int TanhGrad(const float *src0, const float *src1, size_t length, float *dst);
+int HSwishGrad(const float *src0, const float *src1, size_t length, float *dst);
+int HSigmoidGrad(const float *src0, const float *src1, size_t length, float *dst);
+int EluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha);
+int GeluGrad(const float *src0, const float *src1, size_t length, float *dst);
+int SoftplusGrad(const float *src0, const float *src1, int length, float *dst);
+int HardShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd);
+int SoftShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_simd.h.in
new file mode 100644
index 00000000..1f98839b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/activation_grad_simd.h.in
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef NNACL_FP32_GRAD_ACTIVATION_GRAD_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_GRAD_ACTIVATION_GRAD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ShrinkGrad@SIMD_INSTRUCTION@(int index, const float *src0, const float *src1, + int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src0_t = SIMD_LD_F32(src0 + index); + SIMD_F32 src1_t = SIMD_LD_F32(src1 + index); + + SIMD_MASK mask0 = SIMD_CMPLE_F32(src1_t, pos_lamdb_v); + SIMD_MASK mask1 = SIMD_CMPLE_F32(neg_lamdb_v, src1_t); + SIMD_MASK mask = SIMD_AND_MASK(mask0, mask1); + + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src0_t, SIMD_MOV_F32(0.0f), mask)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.c new file mode 100644 index 00000000..92d494f0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.c @@ -0,0 +1,48 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/apply_proximal_adagrad_fp32_simd.h" + +int Sign(float x) { + if (x > 0) { + return 1; + } + if (x < 0) { + return -1; + } + return 0; +} + +void ApplyProximalAdagradOpt(float *var, float *accum, float lr, float l1, float l2, float *grad, + int64_t input_elements) { + int64_t i = 0; + + SIMD_RUN_NO_SCALAR(ApplyProximalAdagradOpt, i, var, accum, lr, l1, l2, grad, input_elements); + + for (; i < input_elements; ++i) { + accum[i] += grad[i] * grad[i]; + float learning_rate = lr / sqrt(accum[i]); + float prox_v = var[i]; + prox_v -= grad[i] * learning_rate; + + if (l1 > 0) { + var[i] = Sign(prox_v) * fmax(fabs(prox_v) - learning_rate * l1, 0.0) / (1 + l2 * learning_rate); + } else { + var[i] = prox_v / (1 + l2 * learning_rate); + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h new file mode 100644 index 00000000..00bc67a1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ +#define NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ApplyProximalAdagradOpt(float *var, float *accum, float lr, float l1, float l2, float *grad, + int64_t input_elements); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in new file mode 100644 index 00000000..70a43429 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in @@ -0,0 +1,68 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ApplyProximalAdagradOpt@SIMD_INSTRUCTION@( + int64_t index, float *var, float *accum, float lr, float l1, float l2, float *grad, int64_t size) { + SIMD_F32 lr_vec = SIMD_MOV_F32(lr); + SIMD_F32 l1_vec = SIMD_MOV_F32(l1); + SIMD_F32 l2_vec = SIMD_MOV_F32(l2); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp_vec1 = SIMD_LD_F32(grad + index); + SIMD_F32 accum_vec = SIMD_LD_F32(accum + index); + SIMD_F32 prox_v_vec = SIMD_LD_F32(var + index); + + accum_vec = SIMD_FMADD_F32(tmp_vec1, tmp_vec1, accum_vec); + SIMD_F32 learn_rate_vec = SIMD_DIV_F32(lr_vec, SIMD_SQRT_F32(accum_vec)); + prox_v_vec = SIMD_SUB_F32(prox_v_vec, SIMD_MUL_F32(tmp_vec1, learn_rate_vec)); + SIMD_ST_F32(accum + index, accum_vec); + tmp_vec1 = SIMD_FMADD_F32(l2_vec, learn_rate_vec, SIMD_MOV_F32(1)); + if (l1 > 0) { + learn_rate_vec = SIMD_MUL_F32(learn_rate_vec, l1_vec); + learn_rate_vec = SIMD_SUB_F32(SIMD_ABS_F32(prox_v_vec), learn_rate_vec); + learn_rate_vec = SIMD_MAX_F32(learn_rate_vec, SIMD_MOV_F32(0.0f)); + learn_rate_vec = SIMD_DIV_F32(learn_rate_vec, tmp_vec1); + + SIMD_MASK greater_mask = SIMD_CMPGT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_MASK less_mask = SIMD_CMPLT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_F32 greater_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, greater_mask); + SIMD_F32 less_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, less_mask); + greater_v = SIMD_SUB_F32(greater_v, less_v); + + prox_v_vec = SIMD_MUL_F32(learn_rate_vec, greater_v); + } else { + prox_v_vec = SIMD_DIV_F32(prox_v_vec, tmp_vec1); + } + SIMD_ST_F32(var + index, prox_v_vec); + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.c new file mode 100644 index 00000000..82c66245 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.c @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/apply_proximal_gradient_descent_fp32_simd.h" + +void ApplyProximalGradientDescentOpt(float *var, float alpha, float l1, float l2, float *delta, + int64_t input_elements) { + int64_t i = 0; + SIMD_RUN_NO_SCALAR(ApplyProximalGradientDescentOpt, i, var, alpha, l1, l2, delta, input_elements); + for (; i < input_elements; ++i) { + float prox_v = var[i]; + prox_v -= delta[i] * alpha; + + if (l1 > 0) { + var[i] = SignFp32(prox_v) * fmax(fabs(prox_v) - alpha * l1, 0.0) / (1 + l2 * alpha); + } else { + var[i] = prox_v / (1 + l2 * alpha); + } + } +} + +float SignFp32(const float x) { + if (x > 0.0) { + return 1.0; + } + if (x < 0.0) { + return -1.0; + } + return 0.0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h new file mode 100644 index 00000000..3e519010 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ +#define NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ApplyProximalGradientDescentOpt(float *var, float alpha, float l1, float l2, float *delta, int64_t input_elements); +float SignFp32(const float x); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in new file mode 100644 index 00000000..f885665b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in @@ -0,0 +1,64 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_@SIMD_INSTRUCTION@_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ApplyProximalGradientDescentOpt@SIMD_INSTRUCTION@( + int64_t index, float *var, float alpha, float l1, float l2, float *delta, int64_t size) { + SIMD_F32 alpha_vec = SIMD_MOV_F32(alpha); + SIMD_F32 l1_vec = SIMD_MOV_F32(l1); + SIMD_F32 l2_vec = SIMD_MOV_F32(l2); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 delta_vec = SIMD_LD_F32(delta + index); + SIMD_F32 prox_v_vec = SIMD_LD_F32(var + index); + + prox_v_vec = SIMD_SUB_F32(prox_v_vec, SIMD_MUL_F32(delta_vec, alpha_vec)); + SIMD_F32 tmp_vec1 = SIMD_FMADD_F32(l2_vec, alpha_vec, SIMD_MOV_F32(1)); + if (l1 > 0) { + SIMD_F32 tmp_vec2 = SIMD_MUL_F32(alpha_vec, l1_vec); + tmp_vec2 = SIMD_SUB_F32(SIMD_ABS_F32(prox_v_vec), tmp_vec2); + tmp_vec2 = SIMD_MAX_F32(tmp_vec2, SIMD_MOV_F32(0.0f)); + tmp_vec2 = SIMD_DIV_F32(tmp_vec2, tmp_vec1); + + SIMD_MASK greater_mask = SIMD_CMPGT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_MASK less_mask = SIMD_CMPLT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_F32 greater_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, greater_mask); + SIMD_F32 less_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, less_mask); + greater_v = SIMD_SUB_F32(greater_v, less_v); + + prox_v_vec = SIMD_MUL_F32(tmp_vec2, greater_v); + } else { + prox_v_vec = SIMD_DIV_F32(prox_v_vec, tmp_vec1); + } + SIMD_ST_F32(var + index, prox_v_vec); + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.c new file mode 100644 index 00000000..0cb6d201 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.c @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32_grad/arithmetic_grad.h" +#include +#include +#include "nnacl_c/fp32_grad/utils.h" +#include "nnacl_c/errorcode.h" + +void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -nom[i] / (denom[i] * denom[i]); + } +} + +void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -a[i] * b[i] / (denom[i] * denom[i]); + } +} + +int ElementAbsGrad(const float *in1, const float *in2, float *out, int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = (in1[i] < 0.f) ? 
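+    /* |x| subgradient: route dy by sign(in1), with 0 chosen at in1 == 0 */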
-in2[i] : ((in1[i] > 0.f) ? in2[i] : 0); + } + return NNACL_OK; +} + +void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] > input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] >= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float)); // zero output + memset(output1, 0, num_output1 * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1 && num_axes0 < C5NUM) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1 && num_axes1 < C5NUM) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] > input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] >= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] < input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] <= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float)); // zero output + memset(output1, 0, num_output1 * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1 && num_axes0 < C5NUM) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1 && num_axes1 < C5NUM) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] < input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] <= input0[offset0] ? 
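+      /* broadcast case: dy flows to whichever operand attains the min (ties go to input1 via <=); reduced axes accumulate with += */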
dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = 0.5f * in2[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = -0.5f * in2[i] * in1[i] * in1[i] * in1[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.h new file mode 100644 index 00000000..b46b72dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/arithmetic_grad.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ +#define NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); +void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); +int ElementAbsGrad(const float *in1, const float *in2, float *out, int element_size); +void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); +void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); +int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size); +int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.c new file mode 100644 index 00000000..3787acc5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.c @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/fp32_grad/batch_norm_grad.h"
+
+void var2Invar(float *save_var, int size, float eps) {
+  for (int i = 0; i < size; i++) {
+    save_var[i] = 1.0f / sqrtf(save_var[i] + eps);
+  }
+}
+
+static void backwardComputeDx(const float *in, const float *yt, const float *mean, const float *invar,
+                              const float *scale, int size, int ch, const float *dbias, const float *dscale, float *dx,
+                              float N, bool is_train) {
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      // dx_2
+      int ix = i * ch + c;
+      dx[ix] = yt[ix];
+      if (is_train) {
+        dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
+      }
+      dx[ix] *= scale[c] * invar[c];
+    }
+  }
+}
+
+#ifdef _MSC_VER
+void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
+                 int ch, float *dbias, float *dscale, float *dx, bool is_train) {
+#else
+void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean,
+                 const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias,
+                 float *restrict dscale, float *restrict dx, bool is_train) {
+#endif
+  NNACL_CHECK_ZERO_RETURN(size);
+  float N = (float)size;
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      int ix = i * ch + c;
+      dbias[c] += yt[ix];
+      // x_hat should also be scaled by invar[c]; that factor is applied once after the loop.
+      float x_hat = in[ix] - mean[c];
+      dscale[c] += (yt[ix] * x_hat);
+    }
+  }
+  for (int c = 0; c < ch; c++) {
+    dscale[c] *= invar[c];
+  }
+  backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train);
+}
+
+#ifdef _MSC_VER
+void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
+                int ch, float *dbias, float *dscale) {
+#else
+void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean,
+                const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias,
+                float *restrict dscale) {
+#endif
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      int ix = i * ch + c;
+      dbias[c] += yt[ix];
+      // x_hat should also be scaled by invar[c]; that factor is applied once after the loop.
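+      // backwardP1 produces only the dbias/dscale reductions; backwardP2 later applies backwardComputeDx per slice.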
+ float x_hat = in[ix] - mean[c]; + dscale[c] += (yt[ix] * x_hat); + } + } + for (int c = 0; c < ch; c++) { + dscale[c] *= invar[c]; + } +} + +#ifdef _MSC_VER +void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale, + const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train) { +#else +void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean, + const float *restrict invar, const float *restrict dscale, const float *restrict dbias, + const float *restrict scale, int size, int total_size, int ch, float *restrict dx, bool is_train) { +#endif + NNACL_CHECK_ZERO_RETURN(total_size); + const float N = (float)total_size; + backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.h new file mode 100644 index 00000000..a8c03fb4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_grad.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ +#define CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ + +#include "nnacl_c/fp32_grad/batch_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void var2Invar(float *save_var, int size, float eps); +void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale, float *dx, bool is_train); +void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale); +void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale, + const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train); +#ifdef __cplusplus +} +#endif + +#endif // CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_parameter.h new file mode 100644 index 00000000..271a5acd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/batch_norm_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
+#define NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct BNGradParameter {
+  OpParameter op_parameter_;
+  float epsilon_;
+  bool is_training_;
+} BNGradParameter;
+
+#endif  // NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.c
new file mode 100644
index 00000000..ed422911
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.c
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl_c/fp32_grad/binary_cross_entropy.h"
+
+static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x,
+                                         const float *input_y, const float *weight, float *loss, float *tmp_loss,
+                                         bool weight_defined) {
+  const float epsilon = 1e-12;
+
+  if (reduction == Reduction_None) {
+    if (weight_defined) {
+      for (int i = 0; i < input_size; i++) {
+        float value =
+          -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        loss[i] = value;
+      }
+    } else {
+      for (int i = 0; i < input_size; i++) {
+        float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        loss[i] = value;
+      }
+    }
+  } else {
+    if (weight_defined) {
+      for (int i = 0; i < input_size; i++) {
+        float value =
+          -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        tmp_loss[i] = value;
+      }
+    } else {
+      for (int i = 0; i < input_size; i++) {
+        float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
+        tmp_loss[i] = value;
+      }
+    }
+  }
+}
+
+void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
+                        const float *weight, float *loss, float *tmp_loss, bool weight_defined) {
+  loss[0] = 0.0f;
+  BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss, weight_defined);
+  if (reduction != Reduction_None) {
+    if (input_size % 2 == 1) {
+      tmp_loss[0] += tmp_loss[input_size - 1];
+    }
+    // halving tree reduction; odd tails are folded into tmp_loss[0]
+    for (int stride = input_size / 2; stride > 0; stride = stride / 2) {
+      for (int i = 0; i < stride; i++) {
+        tmp_loss[i] += tmp_loss[i + stride];
+      }
+      if (stride > 2 && stride % 2 == 1) {
+        tmp_loss[0] += tmp_loss[stride - 1];
+      }
+    }
+    loss[0]
+= tmp_loss[0]; + if (reduction == Reduction_Mean) { + loss[0] /= input_size; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.h new file mode 100644 index 00000000..348d130f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BINARY_CROSS_ENTROPY_H_ +#define NNACL_BINARY_CROSS_ENTROPY_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int reduction; +} BinaryCrossEntropyParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, float *loss, float *tmp_loss, bool weight_defined); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_BINARY_CROSS_ENTROPY_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.c new file mode 100644 index 00000000..c4d17754 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.c @@ -0,0 +1,56 @@ +/* + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32_grad/binary_cross_entropy_grad.h" + +#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) + +int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, const float *dloss, float *dx, bool weight_defined) { + const float epsilon = 1e-12f; + if (reduction == Reduction_None) { + if (weight_defined) { + for (int i = 0; i < input_size; i++) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } else { + for (int i = 0; i < input_size; i++) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } + } else { + float dloss1 = dloss[0]; + if (reduction == Reduction_Mean) { + dloss1 = dloss[0] / input_size; + } + for (int i = 0; i < input_size; i++) { + if (weight_defined) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } else { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } + } + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.h new file mode 100644 index 00000000..bc8c3dbb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/binary_cross_entropy_grad.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ +#define NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int reduction; +} BinaryCrossEntropyGradParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, const float *dloss, float *dx, bool weight_defined); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.c new file mode 100644 index 00000000..b6de2e76 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.c @@ -0,0 +1,380 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/convolution_grad_filter.h"
+#include "nnacl_c/errorcode.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+#ifdef ENABLE_ARM
+static int FilterGrad16Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                           const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  int i_kh = k_idx / k_w;
+  int i_kw = k_idx % k_w;
+  // handle 16 channels per step with four NEON accumulators; returns the next unprocessed channel
+  for (; i_c < (out_ch & ~15); i_c += 16) {
+    float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_12x_4 = vdupq_n_f32(0.0f);
+    for (int b = 0; b < batch; ++b) {
+      const float *x_addr = &x[b * x_size];
+      const float *dy_addr = &dy[b * y_size];
+      for (int i = 0; i < m; i++) {
+        int idx = i;
+        int input_h = idx / out_w * conv_param->stride_h_;
+        int input_w = idx % out_w * conv_param->stride_w_;
+        int input_row = -conv_param->pad_u_ + i_kh + input_h;
+        int input_col = -conv_param->pad_l_ + i_kw + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+          int offset_dy = idx * out_ch + i_c;
+
+          float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
+          float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
+          sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
+
+          float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
+          float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
+          sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
+
+          float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
+          float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
+          sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
+
+          float32x4_t x_12x_4 = vld1q_f32(x_addr + offset_x + 12);
+          float32x4_t dy_12x_4 = vld1q_f32(dy_addr + offset_dy + 12);
+          sum_12x_4 = vmlaq_f32(sum_12x_4, x_12x_4, dy_12x_4);
+        }
+      }
+    }
+    dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
+    dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
+    dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
+    dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
+
+    dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
+    dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
+    dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
+    dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
+
+    dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
+    dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
+    dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
+    dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
+
+    dw[(i_c + 12) * k_spatial + k_idx] = sum_12x_4[0];
+    dw[(i_c + 13) * k_spatial + k_idx] = sum_12x_4[1];
+    dw[(i_c + 14) * k_spatial + k_idx] = sum_12x_4[2];
+    dw[(i_c + 15) * k_spatial + k_idx] =
sum_12x_4[3]; + } + return i_c; +} + +static int FilterGrad12Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 12) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + float32x4_t sum_9x_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + + float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8); + float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8); + sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + + dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0]; + dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1]; + dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2]; + dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3]; + + i_c += 12; + } + return i_c; +} + +static int FilterGrad8Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if ((out_ch - i_c) >= 8) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int 
input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + i_c += 8; + } + return i_c; +} +static int FilterGrad4Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 4) { + float32x4_t sum_4 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_4 = vld1q_f32(dy_addr + offset_dy); + sum_4 = vmlaq_f32(sum_4, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_4[3]; + i_c += 4; + } + return i_c; +} + +static int Filtergrad2Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if 
((out_ch - i_c) >= 2) { + float32x2_t sum_2 = vdup_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x2_t x_4 = vld1_f32(x_addr + offset_x); + float32x2_t dy_4 = vld1_f32(dy_addr + offset_dy); + sum_2 = vmla_f32(sum_2, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_2[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_2[1]; + i_c += 2; + } + return i_c; +} +#endif +int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + + for (int i_k = 0; i_k < count; i_k++) { + int k_idx = start + i_k; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int i_c = 0; +#ifdef ENABLE_ARM + i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad12Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = Filtergrad2Arm(x, dy, i_c, k_idx, dw, conv_param); +#endif + for (; i_c < out_ch; i_c++) { + float sum = 0; + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + sum += x_addr[offset_x] * dy_addr[offset_dy]; + } + } + } + dw[i_c * k_spatial + k_idx] = sum; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.h new file mode 100644 index 00000000..1ed95e72 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_filter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
+#define NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
+
+#include <stddef.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
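+// Computes the depthwise convolution filter gradient: for each kernel position
+// k_idx in [start, start + count), dw[c * k_h * k_w + k_idx] is the sum over
+// the batch and all output positions of x * dy for channel c.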
+int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.c
new file mode 100644
index 00000000..c791ed91
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.c
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/convolution_grad_input.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+int ConvDwInputGrad(const float *dy, const float *w, float *dx, int start, int count, const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int out_w = conv_param->output_w_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_spatial = conv_param->output_h_ * conv_param->output_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int k_spatial = k_h * k_w;
+  int end = start + count;
+
+  // four channels at a time; the scalar loop below handles the remainder
+  int j = start;
+  for (; j <= (end - C4NUM); j += C4NUM) {
+    float *c = dx + j;
+    const float *mat_b_0 = w + (j + C0NUM) * k_spatial;
+    const float *mat_b_1 = w + (j + C1NUM) * k_spatial;
+    const float *mat_b_2 = w + (j + C2NUM) * k_spatial;
+    const float *mat_b_3 = w + (j + C3NUM) * k_spatial;
+
+    for (int si = 0; si < out_spatial; si++) {
+      const float *a = dy + j + si * out_ch;
+#ifdef ENABLE_ARM
+      float32x4_t mat_a = vld1q_f32(a);
+#else
+      float mat_a[C4NUM] = {a[C0NUM], a[C1NUM], a[C2NUM], a[C3NUM]};
+#endif
+      int output_row = (si) / out_w;
+      int output_col = (si) % out_w;
+      for (int k = 0; k < k_spatial; k++) {
+        int row_stride_offset = output_row * conv_param->stride_h_;
+        int col_stride_offset = output_col * conv_param->stride_w_;
+        int kernel_row = k / k_w;
+        int kernel_col = k % k_w;
+        int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset = (input_row * in_w + input_col) * in_ch;
+#ifdef ENABLE_ARM
+          float32x4_t mat_b = {mat_b_0[k], mat_b_1[k], mat_b_2[k], mat_b_3[k]};
+          float32x4_t mat_c = vld1q_f32(c + offset);
+          mat_c = vmlaq_f32(mat_c, mat_b, mat_a);
+          vst1q_f32(c + offset, mat_c);
+#else
+          c[offset + C0NUM] += mat_a[C0NUM] * mat_b_0[k];
+          c[offset + C1NUM] += mat_a[C1NUM] * mat_b_1[k];
+          c[offset + C2NUM] += mat_a[C2NUM] * mat_b_2[k];
+          c[offset + C3NUM] += mat_a[C3NUM] * mat_b_3[k];
+#endif
+        }
+      }
+    }
+  }
+
+  for (; j < end; j++) {
+    float *c = dx + j;
+    const float *b = w + j * k_spatial;
+    for (int si = 0; si < out_spatial; si++) {
+      const float *a = dy + j + si * out_ch;
+      int output_row = si / out_w;
+      int output_col = si % out_w;
+      int row_stride_offset = output_row * conv_param->stride_h_;
+      int col_stride_offset = output_col * conv_param->stride_w_;
+      for (int k = 0; k < k_spatial; k++) {
+        int kernel_row = k / k_w;
+        int kernel_col = k % k_w;
+        int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset;
+        int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset = (input_row * in_w + input_col) * in_ch;
+          c[offset] += a[0] * b[k];
+        }
+      }
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.h
new file mode 100644
index 00000000..734b53df
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/convolution_grad_input.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
+#define NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
+
+#include <stddef.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
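+// Computes the depthwise convolution input gradient for channels
+// [start, start + count): dx accumulates dy * w over every kernel position
+// that maps back onto a valid input pixel (stride, dilation and padding applied).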
+int ConvDwInputGrad(const float *dy, const float *w, float *dx, int start, int count, const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.c
new file mode 100644
index 00000000..c04aee36
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.c
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/dropout_grad.h"
+
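+// Dropout backward: gradients flow only through the positions kept by mask,
+// rescaled by scale (for inverted dropout this is typically 1.0f / (1.0f - ratio)).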
+void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float scale) {
+  for (int i = 0; i < length; i++) {
+    output_ptr[i] = yt_ptr[i] * mask[i] * scale;
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.h
new file mode 100644
index 00000000..1338ec78
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_grad.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_DROPOUT_GRAD_H_
+#define NNACL_FP32_GRAD_DROPOUT_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float scale);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_DROPOUT_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_parameter.h
new file mode 100644
index 00000000..51015ed3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/dropout_parameter.h
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
+#define NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct {
+  OpParameter op_parameter_;
+  float ratio_;
+} DropoutParameter;
+
+#endif  // NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c
new file mode 100644
index 00000000..8f51b499
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.c
@@ -0,0 +1,855 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/gemm.h"
+#include <string.h>
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+#ifdef _MSC_VER
+void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride) {
+#else
+void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) {
+#endif
+  const float *src_ptr = v1;
+  float *dst_ptr = v2;
+  for (int r = 0; r < row; r++) {
+    for (int c = 0; c < col; c++) {
+      dst_ptr[c] += beta * src_ptr[c];
+    }
+    src_ptr += stride;
+    dst_ptr += stride;
+  }
+}
+
+int MatSize(int row, int col, int round) {
+  int res = UP_ROUND(row, round) * col;
+  int res1 = UP_ROUND(col, round) * row;
+  return (res > res1) ?
res : res1; +} + +int MatSizeTotal(int row, int col, int deep, int stride) { +#ifdef ENABLE_ARM32 + const int num0 = C4NUM; +#elif ENABLE_AVX + const int num0 = C6NUM; +#else + const int num0 = C12NUM; +#endif + +#ifdef ENABLE_AVX + const int num1 = C16NUM; +#else + const int num1 = C8NUM; +#endif + int res = MatSize(row, deep, num0) + MatSize(col, deep, num1); + if (stride > 0) res += row * stride; + return res; +} +#ifdef ENABLE_ARM32 +static void RowMajor2Row4MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + for (int c = 0; c < col; c++) { + int cd8 = c / 4; + int cm8 = c % 4; + dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c]; + } + } +} +#endif + +void RowMajor2Row8MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + for (int c = 0; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c]; + } + } + return; +} + +#ifdef ENABLE_ARM64 +static void RowMajor2Col12MajorStrideArm64(const float *src_c, float *dst_c, int lead) { + size_t stride = lead * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s}, [x10], %[stride]\n" + "ld1 {v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s}, [x10], %[stride]\n" + "ld1 {v3.4s}, [x10], %[stride]\n" + + "ld1 {v4.4s}, [x10], %[stride]\n" + "ld1 {v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s}, [x10], %[stride]\n" + "ld1 {v7.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v0.4s, v1.4s\n" + "zip2 v13.4s, v0.4s, v1.4s\n" + "zip1 v14.4s, v2.4s, v3.4s\n" + "zip2 v15.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x10], %[stride]\n" + "ld1 {v9.4s}, [x10], %[stride]\n" + "ld1 {v10.4s}, [x10], %[stride]\n" + "ld1 {v11.4s}, [x10], %[stride]\n" + + "zip1 v16.4s, v4.4s, v5.4s\n" + "zip2 v17.4s, v4.4s, v5.4s\n" + "zip1 v18.4s, v6.4s, v7.4s\n" + "zip2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v23.2d, v12.2d, v14.2d\n" + "trn1 v26.2d, v13.2d, v15.2d\n" + "trn2 v29.2d, v13.2d, v15.2d\n" + + "trn1 v21.2d, v16.2d, v18.2d\n" + "trn2 v24.2d, v16.2d, v18.2d\n" + "trn1 v27.2d, v17.2d, v19.2d\n" + "trn2 v30.2d, v17.2d, v19.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif // ENABLE_ARM64 + +#ifdef ENABLE_ARM32 +void RowMajor2Col12MajorStrideArm32(const float *src_c, float *dst_c, int lead) { + size_t stride = lead * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q10}, [r10], %[stride]\n" + "vld1.32 {q13}, [r10], %[stride]\n" + + "vtrn.32 d0, d6\n" + "vtrn.32 d1, d7\n" + "vtrn.32 d20, d26\n" + "vtrn.32 d21, d27\n" + + "vld1.32 {q1}, [r10], %[stride]\n" 
+ "vld1.32 {q8}, [r10], %[stride]\n" + "vld1.32 {q11}, [r10], %[stride]\n" + "vld1.32 {q14}, [r10], %[stride]\n" + + "vswp d1, d20\n" + "vswp d7, d26\n" + + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q9}, [r10], %[stride]\n" + "vld1.32 {q12}, [r10], %[stride]\n" + "vld1.32 {q15}, [r10], %[stride]\n" + + "vtrn.32 d2, d16\n" + "vtrn.32 d3, d17\n" + "vtrn.32 d22, d28\n" + "vtrn.32 d23, d29\n" + + "vswp d3, d22\n" + "vswp d17, d28\n" + + "vtrn.32 d4, d18\n" + "vtrn.32 d5, d19\n" + "vtrn.32 d24, d30\n" + "vtrn.32 d25, d31\n" + + "vswp d5, d24\n" + "vswp d19, d30\n" + + "vst1.32 {q0, q1}, [r12]!\n" + "vst1.32 {q2, q3}, [r12]!\n" + "vst1.32 {q8, q9}, [r12]!\n" + "vst1.32 {q10, q11}, [r12]!\n" + "vst1.32 {q12, q13}, [r12]!\n" + "vst1.32 {q14, q15}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +} +#endif // ENABLE_ARM32 + +#ifndef ENABLE_ARM32 +void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + for (int c = 0; c < col; c++) { + int cd8 = c / C12NUM; + int cm8 = c % C12NUM; + dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c]; + } + } + return; +} + +void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + + /* 12x4 row-major to col-major */ +#ifdef ENABLE_ARM64 + RowMajor2Col12MajorStrideArm64(src_c, dst_c, lead); +#elif ENABLE_ARM32 + RowMajor2Col12MajorStrideArm32(src_c, dst_c, lead); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + for (int i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C12NUM * lead; + dst_r += C12NUM * col; + } + + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + for (; ri < row_up_12; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} +#endif + +#ifdef ENABLE_ARM64 +static void RowMajor2Col8MajorStrideArm64(const float *src_c, float *dst_c, int lead) { + /* 8x8 row-major to col-major */ + size_t stride = lead * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n" + "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + + "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n" + "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n" + "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n" + "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v0.2d, v8.2d, v10.2d\n" + "trn2 v1.2d, v8.2d, 
v10.2d\n" + "trn1 v2.2d, v9.2d, v11.2d\n" + "trn2 v3.2d, v9.2d, v11.2d\n" + + "zip1 v24.4s, v16.4s, v18.4s\n" + "zip2 v25.4s, v16.4s, v18.4s\n" + "zip1 v26.4s, v20.4s, v22.4s\n" + "zip2 v27.4s, v20.4s, v22.4s\n" + + "trn1 v4.2d, v12.2d, v14.2d\n" + "trn2 v5.2d, v12.2d, v14.2d\n" + "trn1 v6.2d, v13.2d, v15.2d\n" + "trn2 v7.2d, v13.2d, v15.2d\n" + + "zip1 v28.4s, v17.4s, v19.4s\n" + "zip2 v29.4s, v17.4s, v19.4s\n" + "zip1 v30.4s, v21.4s, v23.4s\n" + "zip2 v31.4s, v21.4s, v23.4s\n" + + "trn1 v16.2d, v24.2d, v26.2d\n" + "trn2 v17.2d, v24.2d, v26.2d\n" + "trn1 v18.2d, v25.2d, v27.2d\n" + "trn2 v19.2d, v25.2d, v27.2d\n" + + "trn1 v20.2d, v28.2d, v30.2d\n" + "trn2 v21.2d, v28.2d, v30.2d\n" + "trn1 v22.2d, v29.2d, v31.2d\n" + "trn2 v23.2d, v29.2d, v31.2d\n" + + "st1 {v0.4s}, [x11], #16\n" + "st1 {v16.4s}, [x11], #16\n" + "st1 {v1.4s}, [x11], #16\n" + "st1 {v17.4s}, [x11], #16\n" + "st1 {v2.4s}, [x11], #16\n" + "st1 {v18.4s}, [x11], #16\n" + "st1 {v3.4s}, [x11], #16\n" + "st1 {v19.4s}, [x11], #16\n" + "st1 {v4.4s}, [x11], #16\n" + "st1 {v20.4s}, [x11], #16\n" + "st1 {v5.4s}, [x11], #16\n" + "st1 {v21.4s}, [x11], #16\n" + "st1 {v6.4s}, [x11], #16\n" + "st1 {v22.4s}, [x11], #16\n" + "st1 {v7.4s}, [x11], #16\n" + "st1 {v23.4s}, [x11], #16\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif // ENABLE_ARM64 + +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +static void RowMajor2Col8MajorStrideArm32(const float *src_c, float *dst_c, size_t col) { + /* 8x4 row-major to col-major */ + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r11, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r11]!\n" + "vst1.32 {q2, q3}, [r11]!\n" + "vst1.32 {q4, q5}, [r11]!\n" + "vst1.32 {q6, q7}, [r11]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} + +#else +static void RowMajor2Col8MajorStrideArm32Nnie(const float *src_c, float *dst_c, size_t col) { + /* 8x4 row-major to col-major */ + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r7, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r7]!\n" + "vst1.32 {q2, q3}, [r7]!\n" + "vst1.32 
{q4, q5}, [r7]!\n" + "vst1.32 {q6, q7}, [r7]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} +#endif // SUPPORT_NNIE +#endif // ENABLE_ARM32 + +void RowMajor2Col8MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row8 = row / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + size_t col_skip = col / C8NUM * C8NUM; + size_t skip_size = C8NUM; +#else + size_t col_skip = col / C4NUM * C4NUM; + size_t skip_size = C4NUM; +#endif + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row8; ri += C8NUM) { + size_t ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8MajorStrideArm64(src_c, dst_c, lead); +#elif ENABLE_ARM32 +#ifndef SUPPORT_NNIE + RowMajor2Col8MajorStrideArm32(src_c, dst_c, lead); +#else + RowMajor2Col8MajorStrideArm32Nnie(src_c, dst_c, lead); +#endif +#else + for (int tr = 0; tr < 8; tr++) { + for (int tc = 0; tc < 4; tc++) { + dst_c[tc * 8 + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C8NUM * lead; + dst_r += C8NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + return; +} +#ifdef ENABLE_ARM32 +static void RowMajor2Col4MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row8 = row / C4NUM * C4NUM; + size_t col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row8; ri += C4NUM) { + size_t ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + + /* 4x4 row-major to col-major */ +#ifdef ENABLE_ARM32 + size_t stride = col * 4; + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + + "vtrn.32 d0, d2\n" + "vtrn.32 d1, d3\n" + "vtrn.32 d4, d6\n" + "vtrn.32 d5, d7\n" + + "vswp d1, d4\n" + "vswp d3, d6\n" + + "vst1.32 {q0}, [r12]!\n" + "vst1.32 {q1}, [r12]!\n" + "vst1.32 {q2}, [r12]!\n" + "vst1.32 {q3}, [r12]!\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [stride] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3"); +#else + for (int tr = 0; tr < C4NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C4NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + return; +} +#endif + +void RowMajor2Row6MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int max = 0; + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + int c = 0; + for (; c < col; c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + int offset = cd6 * C6NUM * row + r * C6NUM + cm6; + dst_ptr[offset] = src[c]; + if (offset 
> max) { + max = offset; + } + } + for (; c < UP_ROUND(col, C6NUM); c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + int offset = cd6 * C6NUM * row + r * C6NUM + cm6; + dst_ptr[offset] = 0.0f; + if (offset > max) { + max = offset; + } + } + } + return; +} + +void RowMajor2Col6MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int totalRow = UP_ROUND(row, C6NUM); + int row6 = row / C6NUM * C6NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + int ri = 0; + for (; ri < row6; ri += C6NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + +#ifdef ENABLE_AVX + __m256 src0 = _mm256_loadu_ps(src_c); + __m256 src1 = _mm256_loadu_ps(src_c + lead); + __m256 src2 = _mm256_loadu_ps(src_c + 2 * lead); + __m256 src3 = _mm256_loadu_ps(src_c + 3 * lead); + __m256 src4 = _mm256_loadu_ps(src_c + 4 * lead); + __m256 src5 = _mm256_loadu_ps(src_c + 5 * lead); + __m256 trans0 = _mm256_unpacklo_ps(src0, src1); + __m256 trans1 = _mm256_unpacklo_ps(src2, src3); + __m256 trans2 = _mm256_unpacklo_ps(src4, src5); + __m256 trans3 = _mm256_unpackhi_ps(src0, src1); + __m256 trans4 = _mm256_unpackhi_ps(src2, src3); + __m256 trans5 = _mm256_unpackhi_ps(src4, src5); + __m128 lo0 = _mm256_castps256_ps128(trans0); + __m128 lo1 = _mm256_castps256_ps128(trans1); + __m128 lo2 = _mm256_castps256_ps128(trans2); + __m128 lo3 = _mm256_castps256_ps128(trans3); + __m128 lo4 = _mm256_castps256_ps128(trans4); + __m128 lo5 = _mm256_castps256_ps128(trans5); + __m128 hi0 = _mm256_extractf128_ps(trans0, 1); + __m128 hi1 = _mm256_extractf128_ps(trans1, 1); + __m128 hi2 = _mm256_extractf128_ps(trans2, 1); + __m128 hi3 = _mm256_extractf128_ps(trans3, 1); + __m128 hi4 = _mm256_extractf128_ps(trans4, 1); + __m128 hi5 = _mm256_extractf128_ps(trans5, 1); + __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); + _mm_storeu_ps(dst_c, res0); + _mm_storeu_ps(dst_c + C4NUM, res1); + _mm_storeu_ps(dst_c + C8NUM, res2); + _mm_storeu_ps(dst_c + C12NUM, res3); + _mm_storeu_ps(dst_c + C16NUM, res4); + _mm_storeu_ps(dst_c + C20NUM, res5); + _mm_storeu_ps(dst_c + C24NUM, res6); + _mm_storeu_ps(dst_c + C28NUM, res7); + _mm_storeu_ps(dst_c + C32NUM, res8); + _mm_storeu_ps(dst_c + C36NUM, res9); + _mm_storeu_ps(dst_c + C40NUM, res10); + _mm_storeu_ps(dst_c + C44NUM, res11); +#else + for (int tr = 0; tr < C6NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C6NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + for (int i = 0; i < C6NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C6NUM * lead; + dst_r += C6NUM * col; + } + + for (; ri < row; 
ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + for (; ri < totalRow; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Col16MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int row16 = row / C16NUM * C16NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + int ri = 0; + for (; ri < row16; ri += C16NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_AVX + Transpose8X8Fp32Avx(src_c, dst_c, lead, C16NUM); + Transpose8X8Fp32Avx(src_c + C8NUM * lead, dst_c + C8NUM, lead, C16NUM); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (int i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C16NUM * lead; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + int total_row = UP_ROUND(row, C16NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Row16MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int max = 0; + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + int c = 0; + for (; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c]; + if ((cd16 * C16NUM * row + r * C16NUM + cm16) > max) max = cd16 * C16NUM * row + r * C16NUM + cm16; + } + for (; c < UP_ROUND(col, C16NUM); c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0; + if ((cd16 * C16NUM * row + r * C16NUM + cm16) > max) max = cd16 * C16NUM * row + r * C16NUM + cm16; + } + } + return; +} +void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, float beta, float *mat_c, int ldc, float *workspace) { + GemmCb gcb; + gcb.atype = ActType_No; + gcb.ca = 0; + gcb.cb = 0; + gcb.bias = NULL; + gcb.mat_a = NULL; + gcb.mat_b = NULL; + GemmMatmulPlus(ta, tb, M, N, K, alpha, mat_a, lda, mat_b, ldb, beta, mat_c, ldc, workspace, &gcb); +} + +void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, float beta, float *mat_c, int ldc, float *workspace, GemmCb *gcb) { +#ifdef ENABLE_ARM32 + const int num = C4NUM; + const int num1 = C8NUM; +#elif ENABLE_AVX + const int num = C6NUM; + const int num1 = C16NUM; +#else + const int num = C12NUM; + const int num1 = C8NUM; +#endif + float *output = mat_c; + float *fworkspace = workspace; + int incremental = (beta < 0.f) || (beta > 0.f); + float *mat_a_input = (float *)mat_a; + float *mat_b_input = (float *)mat_b; + + if (!gcb->ca) { + mat_a_input = fworkspace; + if (ta) { + fworkspace += MatSize(K, M, num); +#ifdef ENABLE_ARM32 + RowMajor2Row4MajorStride(mat_a, mat_a_input, K, M, lda); +#elif ENABLE_AVX + RowMajor2Row6MajorStride(mat_a, mat_a_input, K, M, lda); +#else + RowMajor2Row12MajorStride(mat_a, mat_a_input, K, M, lda); +#endif + } else { + fworkspace += MatSize(M, K, 
num);
+#ifdef ENABLE_ARM32
+      RowMajor2Col4MajorStride(mat_a, mat_a_input, M, K, lda);
+#elif ENABLE_AVX
+      RowMajor2Col6MajorStride(mat_a, mat_a_input, M, K, lda);
+#else
+      RowMajor2Col12MajorStride(mat_a, mat_a_input, M, K, lda);
+#endif
+    }
+  }
+  if (!gcb->cb) {
+    mat_b_input = fworkspace;
+    if (tb) {
+      fworkspace += MatSize(N, K, num1);
+#ifdef ENABLE_AVX
+      RowMajor2Col16MajorStride(mat_b, mat_b_input, N, K, ldb);
+#else
+      RowMajor2Col8MajorStride(mat_b, mat_b_input, N, K, ldb);
+#endif
+    } else {
+      fworkspace += MatSize(K, N, num1);
+#ifdef ENABLE_AVX
+      RowMajor2Row16MajorStride(mat_b, mat_b_input, K, N, ldb);
+#else
+      RowMajor2Row8MajorStride(mat_b, mat_b_input, K, N, ldb);
+#endif
+    }
+  }
+  if (incremental) output = fworkspace;
+#ifdef ENABLE_ARM32
+  MatmulFloatNeon32Opt(mat_a_input, mat_b_input, output, gcb->bias, (int)gcb->atype, K, M, N, ldc, 1);
+#else
+  MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc);
+#endif
+  if (incremental) AddMatrix(output, mat_c, beta, M, N, ldc);
+  gcb->mat_a = mat_a_input;
+  gcb->mat_b = mat_b_input;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.h
new file mode 100644
index 00000000..9d2f6e49
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/gemm.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_GEMM_H_
+#define NNACL_FP32_GRAD_GEMM_H_
+
+#include <stdlib.h>
+#include "nnacl_c/op_base.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+typedef struct {
+  int ca;
+  int cb;
+  ActType atype;
+  float *bias;
+  float *mat_a;
+  float *mat_b;
+} GemmCb;
+
+void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b,
+                    int ldb, float beta, float *mat_c, int ldc, float *workspace, GemmCb *cb);
+void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b,
+                int ldb, float beta, float *mat_c, int ldc, float *workspace);
+int MatSize(int row, int col, int round);
+int MatSizeTotal(int row, int col, int deep, int inc);
+void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_GEMM_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.c
new file mode 100644
index 00000000..ce0bce92
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.c
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32_grad/layernorm_grad.h"
+#include <stddef.h>
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+
+int LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma,
+                  int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) {
+  // var is actually the variance output of the layer_norm forward pass
+  const float eps = 1e-12;
+  const float *var_sqrt_rev = var;
+  if (block_size <= 0) {
+    return NNACL_ERRCODE_DIVISOR_ZERO;
+  }
+  for (int i = 0; i < param_num; ++i) {
+    float dgamma = 0.0f;
+    float dbeta = 0.0f;
+    for (int j = i; j < param_size * param_num; j += param_num) {
+      int norm_shift = (int)(j / block_size);
+      dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]);
+      dbeta += dy[j];
+    }
+    dg[i] = dgamma;
+    db[i] = dbeta;
+  }
+  for (int i = 0; i < block_num; ++i) {
+    float sum1 = 0.0f;
+    float sum2 = 0.0f;
+    float sum3 = 0.0f;
+    for (int j = 0; j < block_size; ++j) {
+      int index = i * block_size + j;
+      float dxm = x[index] - mean[i];
+      int param_shift = index % param_num;
+      float dyg = dy[index] * gamma[param_shift];
+      sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[i] + eps, -1.5);
+      sum2 += dyg;
+      sum3 += -2.0f * dxm;
+    }
+    for (int j = 0; j < block_size; ++j) {
+      int index = i * block_size + j;
+      float var_sqrt = pow(var_sqrt_rev[i] + eps, -0.5);
+      int param_shift = index % param_num;
+      float dx1 = dy[index] * gamma[param_shift] * var_sqrt;
+      float dx2 = sum1 * 2.0f / block_size * (x[index] - mean[i]);
+      float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size);
+      dx[index] = dx1 + dx2 + dx3;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.h
new file mode 100644
index 00000000..8558c448
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernorm_grad.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ +#define NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, + int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernormgrad_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernormgrad_parameter.h new file mode 100644 index 00000000..ebb3c8b7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/layernormgrad_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ +#define NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int begin_norm_axis_; + int begin_params_axis_; +} LayerNormGradParameter; + +#endif // NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c new file mode 100644 index 00000000..8400af94 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
+#include <string.h>
+#include <float.h>
+#include "nnacl_c/lstm_parameter.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32_grad/gemm.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/nnacl_utils.h"
+
+static const int num_of_gates = 4;
+static const int no_of_temp_matrices_sized_output_step = 5;
+
+static inline float *AllocteFromScrachPad(float **scrach_pad, int size) {
+  float *buffer = *scrach_pad;
+  *scrach_pad += size;
+  return buffer;
+}
+
+static const int weights_order_IOFG[2 * 4] = {0, 3, 1, 2, 4, 7, 5, 6};  // IOFG order to IFGO order
+static const int weights_order_IFGO[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 5};  // IFGO order to IOFG order
+
+const int *getLstmOrderIOFG(void) { return weights_order_IOFG; }
+
+const int *getLstmOrderIFGO(void) { return weights_order_IFGO; }
+
+void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align,
+                             const int *order) {
+  for (int i = 0; i < batch; i++) {
+    const float *src_batch = src + i * col * row;
+    float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col * row_align;
+#ifdef ENABLE_AVX
+    RowMajor2Row16Major(src_batch, dst_batch, row, col);
+#elif defined(ENABLE_ARM32)
+    RowMajor2Row4Major(src_batch, dst_batch, row, col);
+#else
+    RowMajor2Row8Major(src_batch, dst_batch, row, col);
+#endif
+  }
+}
+
+void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, int row, const int *order) {
+  int matrix_size = col * row;
+  for (int i = 0; i < nof_martices; i++) {
+    const float *src_block = src + i * matrix_size;
+    float *dst_block = dst + ((order == NULL) ? i : order[i]) * matrix_size;
+    memcpy(dst_block, src_block, matrix_size * sizeof(float));
+  }
+}
+
+void sumCols(int m, int n, int stride, float *inMat, float *outMat, bool accumulate) {
+  for (int idn = 0; idn < n; idn++) {
+    float *col = inMat + idn;
+    if (!accumulate) {
+      *outMat = 0;
+    }
+    for (int idm = 0; idm < m; idm++) {
+      *outMat += *col;
+      col += stride;
+    }
+    outMat++;
+  }
+}
+
+int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) {
+  int workspace_size, temp;
+  // if the corresponding GemmMatmul call uses beta > 0, MatSizeTotal must take col as its last parameter.
+  workspace_size = MatSizeTotal(batch, input_size, hidden_size, input_size);
+  temp = MatSizeTotal(batch, hidden_size, hidden_size, hidden_size);
+  workspace_size = (temp > workspace_size) ? temp : workspace_size;
+  temp = MatSizeTotal(hidden_size, input_size, batch, input_size);
+  workspace_size = (temp > workspace_size) ? temp : workspace_size;
+  temp = MatSizeTotal(hidden_size, hidden_size, batch, hidden_size);
+  workspace_size = (temp > workspace_size) ?
temp : workspace_size; + return workspace_size; +} + +int GetRunWorkspaceSize(const LstmGradParameter *lstm_param) { + int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; + int workspace_size = no_of_temp_matrices_sized_output_step * time_stamp_len; + workspace_size += GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_); + return workspace_size; +} + +size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param) { + int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; + return no_of_temp_matrices_sized_output_step * time_stamp_len; +} + +void LstmGradReorderDy(float *src, float *dst, LstmGradParameter *lstm_param) { + int dir_mult = lstm_param->bidirectional_ ? C2NUM : C1NUM; + for (int b = 0; b < lstm_param->batch_; b++) { + int batch_offset = b * dir_mult * lstm_param->hidden_size_; + float *dy = src + batch_offset; + memcpy(dst + b * lstm_param->hidden_size_, dy, lstm_param->hidden_size_ * sizeof(float)); + } +} + +void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, + float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, + float *w, float *v, float *workspace, const LstmGradParameter *lstm_param) { + float *scratchPad = workspace; + + int seq_len = lstm_param->batch_ * lstm_param->hidden_size_; + float *temp0 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp1 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp2 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp3 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp4 = AllocteFromScrachPad(&scratchPad, seq_len); + + // Accumulate gradients into dH + ElementAdd(dH, dY, dH, seq_len); + + ElementMul(dH, output_gate, temp1, seq_len); + Tanh(cell_state, seq_len, temp0); + ElementMul(temp0, temp0, temp2, seq_len); + ElementMul(temp1, temp2, temp4, seq_len); + ElementSub(temp1, temp4, temp1, seq_len); + ElementAdd(dC, temp1, dC, seq_len); + + // calculate dI, dO, dF and dG + float *dI = temp1; // dI = dC_{t} * G + ElementMul(dC, cell_gate, dI, seq_len); + float *dO = temp2; // dO = dH * Tanh(C_{t}) + ElementMul(dH, temp0, dO, seq_len); + float *dF = temp3; // dF = dC_{t} * C_{t-1} + ElementMul(dC, prev_cell_state, dF, seq_len); + float *dG = temp4; // dG = dC_{t} * I + ElementMul(dC, input_gate, dG, seq_len); + + // dAi = dI * I * (1 - I) + float *dAi = temp1; + *dA = dAi; + ElementMul(dI, input_gate, dAi, seq_len); + ElementMul(dAi, input_gate, temp0, seq_len); + ElementSub(dAi, temp0, dAi, seq_len); + + // dAo = dO * O * (1 - O) + float *dAo = temp2; + ElementMul(dO, output_gate, dAo, seq_len); + ElementMul(dAo, output_gate, temp0, seq_len); + ElementSub(dAo, temp0, dAo, seq_len); + + // dAf = dF * F * (1 - F) + float *dAf = temp3; + ElementMul(dF, forget_gate, dAf, seq_len); + ElementMul(dAf, forget_gate, temp0, seq_len); + ElementSub(dAf, temp0, dAf, seq_len); + + float *dAg = temp4; + ElementMul(cell_gate, cell_gate, temp0, seq_len); + ElementMul(dG, temp0, temp0, seq_len); + ElementSub(dG, temp0, dAg, seq_len); + + float *mat_workspace = AllocteFromScrachPad( + &scratchPad, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_)); + float *weights_loop = w; + float *dA_loop = dAi; // dAi, dAo, dAf, dAg + for (int idx = 0; idx < num_of_gates; idx++) { + GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop, + lstm_param->hidden_size_, 
weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_,
+               mat_workspace);
+    weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_;
+    dA_loop += seq_len;
+  }
+
+  // calculate dH next
+  size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float);
+  memset(dH, 0, dH_size);
+  dA_loop = dAi;
+  weights_loop = v;
+  for (int idx = 0; idx < num_of_gates; idx++) {
+    GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop,
+               lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_,
+               mat_workspace);
+    weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_;
+    dA_loop += seq_len;
+  }
+  // calculate dC next
+  ElementMul(dC, forget_gate, dC, seq_len);
+}
+
+void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB,
+                          float *workspace, const LstmGradParameter *lstm_param) {
+  // Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg
+  int seq_len = lstm_param->batch_ * lstm_param->hidden_size_;
+  float *mat_workspace = AllocteFromScrachPad(
+    &workspace, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_));
+  float *dA_loop = dA;  // dAi, dAo, dAf, dAg
+  int dW_size = lstm_param->input_size_ * lstm_param->hidden_size_;
+  int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_;
+  int dB_size = 0;
+  float *dW_loop = dW;
+  float *dV_loop = dV;
+  float *dB_loop = NULL;
+  if (lstm_param->has_bias_) {
+    dB_loop = dB;
+    dB_size = lstm_param->hidden_size_;
+  }
+
+  for (int idx = 0; idx < num_of_gates; idx++) {
+    // Calc dW
+    GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->input_size_, lstm_param->batch_, 1.0, dA_loop,
+               lstm_param->hidden_size_, input_t, lstm_param->input_size_, 1.0, dW_loop, lstm_param->input_size_,
+               mat_workspace);
+    // Calc dV
+    GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop,
+               lstm_param->hidden_size_, prev_hidden_state, lstm_param->hidden_size_, 1.0, dV_loop,
+               lstm_param->hidden_size_, mat_workspace);
+    // Calc dB
+    if (dB_loop != NULL) {
+      sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true);
+    }
+    dA_loop += seq_len;
+    dW_loop += dW_size;
+    dV_loop += dV_size;
+    dB_loop += dB_size;
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h
new file mode 100644
index 00000000..12bc3758
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/lstm_grad_fp32.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ */ + +#ifndef NNACL_FP32_GRAD_LSTM_GRAD_H_ +#define NNACL_FP32_GRAD_LSTM_GRAD_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LstmGradParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + float zoneout_cell_; + float zoneout_hidden_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; + int has_bias_; +} LstmGradParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +const int *getLstmOrderIOFG(void); + +const int *getLstmOrderIFGO(void); + +int GetRunWorkspaceSize(const LstmGradParameter *lstm_param); + +size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param); + +void LstmGradReorderDy(float *src, float *dst, LstmGradParameter *lstm_param); + +void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align, + const int *order); + +void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, int row, const int *order); + +void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, + float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, + float *w, float *v, float *workspace, const LstmGradParameter *lstm_param); + +void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB, + float *workspace, const LstmGradParameter *lstm_param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_LSTM_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c new file mode 100644 index 00000000..5ffdb880 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.c @@ -0,0 +1,147 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/fp32_grad/maxpool_grad_grad.h" +#include "nnacl_c/errorcode.h" + +int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + PoolingParameter *param, PoolingComputeParam *args) { + const int channel = args->input_channel_; + const int input_height = args->input_h_; + const int input_width = args->input_w_; + + const int window_height = args->window_h_; + const int window_width = args->window_w_; + + const int stride_height = param->stride_h_; + const int stride_width = param->stride_w_; + + const int pad_top = param->pad_u_; + const int pad_left = param->pad_l_; + + const int output_height = args->output_h_; + NNACL_CHECK_ZERO_RETURN_ERR(output_height); + const int output_width = args->output_w_; + NNACL_CHECK_ZERO_RETURN_ERR(output_width); + + const int output_chw = channel * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_chw); + const int output_hw = output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_hw); + + for (size_t pos = start; pos < end; pos++) { + const int pos_n = pos / output_chw; + const int pos_c = pos / output_hw % channel; + const int pos_h = pos / output_width % output_height; + const int pos_w = pos % output_width; + + int h_start = pos_h * stride_height - pad_top; + int w_start = pos_w * stride_width - pad_left; + const int h_end = MSMIN(h_start + window_height, input_height); + const int w_end = MSMIN(w_start + window_width, input_width); + h_start = MSMAX(h_start, 0); + w_start = MSMAX(w_start, 0); + + int input_start = pos_n * channel * input_height * input_width + pos_c * input_height * input_width; + int max_idx = h_start * input_width + w_start; + float max_data = input[input_start + max_idx]; + + for (int h_cur = h_start; h_cur < h_end; ++h_cur) { + for (int w_cur = w_start; w_cur < w_end; ++w_cur) { + int input_idx = h_cur * input_width + w_cur; + float input_data = input[input_start + input_idx]; + if (input_data > max_data) { + max_idx = input_idx; + max_data = input_data; + } + } + } + output[pos] = grad[input_start + max_idx]; + } + return NNACL_OK; +} + +int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + Pooling3DParameter *param, PoolingComputeParam *args) { + PoolingParameter *param_2d = (PoolingParameter *)(param); + const int channel = args->input_channel_; + const int input_depth = param->input_d_; + const int input_height = args->input_h_; + const int input_width = args->input_w_; + + const int window_depth = param->window_d_; + const int window_height = args->window_h_; + const int window_width = args->window_w_; + + const int stride_depth = param->stride_d_; + const int stride_height = param_2d->stride_h_; + const int stride_width = param_2d->stride_w_; + + const int pad_front = param->pad_f_; + const int pad_top = param_2d->pad_u_; + const int pad_left = param_2d->pad_l_; + + const int output_depth = param->output_d_; + NNACL_CHECK_ZERO_RETURN_ERR(output_depth); + const int output_height = args->output_h_; + NNACL_CHECK_ZERO_RETURN_ERR(output_height); + const int output_width = args->output_w_; + NNACL_CHECK_ZERO_RETURN_ERR(output_width); + + const int output_cdhw = channel * output_depth * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_cdhw); + const int output_dhw = output_depth * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_dhw); + const int output_hw = output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_hw); + + for (size_t pos = start; 
pos < end; pos++) {
+    const int pos_n = pos / output_cdhw;
+    const int pos_c = pos / output_dhw % channel;
+    const int pos_d = pos / output_hw % output_depth;
+    const int pos_h = pos / output_width % output_height;
+    const int pos_w = pos % output_width;
+
+    int d_start = pos_d * stride_depth - pad_front;
+    int h_start = pos_h * stride_height - pad_top;
+    int w_start = pos_w * stride_width - pad_left;
+    const int d_end = MSMIN(d_start + window_depth, input_depth);
+    const int h_end = MSMIN(h_start + window_height, input_height);
+    const int w_end = MSMIN(w_start + window_width, input_width);
+    d_start = MSMAX(d_start, 0);
+    h_start = MSMAX(h_start, 0);
+    w_start = MSMAX(w_start, 0);
+
+    int input_start =
+      pos_n * channel * input_depth * input_height * input_width + pos_c * input_depth * input_height * input_width;
+    int max_idx = d_start * input_height * input_width + h_start * input_width + w_start;
+    float max_data = input[input_start + max_idx];
+
+    for (int d_cur = d_start; d_cur < d_end; ++d_cur) {
+      for (int h_cur = h_start; h_cur < h_end; ++h_cur) {
+        for (int w_cur = w_start; w_cur < w_end; ++w_cur) {
+          int input_idx = d_cur * input_height * input_width + h_cur * input_width + w_cur;
+          float input_data = input[input_start + input_idx];
+          if (input_data > max_data) {
+            max_idx = input_idx;
+            max_data = input_data;
+          }
+        }
+      }
+    }
+    output[pos] = grad[input_start + max_idx];
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h
new file mode 100644
index 00000000..9edeef77
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/maxpool_grad_grad.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_MAXPOOL_GRAD_GRAD_H_
+#define NNACL_FP32_GRAD_MAXPOOL_GRAD_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/pooling_parameter.h"
+#include "nnacl_c/kernel/pooling.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
+                    PoolingParameter *param, PoolingComputeParam *args);
+
+int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end,
+                      Pooling3DParameter *param, PoolingComputeParam *args);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_MAXPOOL_GRAD_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c
new file mode 100644
index 00000000..63c452f8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.c
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/fp32_grad/nllloss_grad_fp32.h" + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" + +int NLLLossGrad(const float *logits, const float *loss_grad, const int *labels, const float *weight, + const float *total_weight, float *logits_grad, int batch, int class_num, ReductionType reduction_type) { + if (logits == NULL || loss_grad == NULL || labels == NULL || weight == NULL || total_weight == NULL || + logits_grad == NULL) { + return NNACL_NULL_PTR; + } + + memset(logits_grad, 0, batch * class_num * sizeof(float)); + for (int i = 0; i < batch; i++) { + int index = i * class_num + labels[i]; + float n_weight = weight[labels[i]]; + if (reduction_type == Reduction_Sum) { + logits_grad[index] = -loss_grad[0] * n_weight; + } else if (reduction_type == Reduction_Mean) { + logits_grad[index] = -loss_grad[0] * n_weight / *total_weight; + } else { + logits_grad[index] = -loss_grad[i] * n_weight; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h new file mode 100644 index 00000000..c49bf25f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/nllloss_grad_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ +#define NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int NLLLossGrad(const float *logits, const float *loss_grad, const int *labels, const float *weight, + const float *total_weight, float *logits_grad, int batch, int class_num, ReductionType reduction_type); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/optimizer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/optimizer.h new file mode 100644 index 00000000..835c40d6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/optimizer.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_OPTIMIZER_H_ +#define NNACL_FP32_GRAD_OPTIMIZER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; + bool use_nesterov_; + float grad_scale_; +} ApplyMomentumParameter; + +typedef struct { + OpParameter op_parameter_; + float dampening_; + bool use_nesterov_; + float weight_decay_; +} SgdParameter; + +typedef struct { + OpParameter op_parameter_; + bool use_nesterov_; +} AdamParameter; + +#endif // NNACL_FP32_GRAD_OPTIMIZER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.c new file mode 100644 index 00000000..235b8599 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.c @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <string.h>
+#include "nnacl_c/fp32_grad/pack_ext.h"
+
+void RollingIm2ColPackDwUnitFp32(const float *in_data, const ConvParameter *conv_param, float *data_col_orig,
+                                 int real_cal_num, int start) {
+  const int pad_left = conv_param->pad_l_;
+  const int pad_up = conv_param->pad_u_;
+
+  const int stride_h = conv_param->stride_h_;
+  const int stride_w = conv_param->stride_w_;
+
+  const int dilation_h = conv_param->dilation_h_;
+  const int dilation_w = conv_param->dilation_w_;
+
+  const int kernel_h = conv_param->kernel_h_;
+  const int kernel_w = conv_param->kernel_w_;
+
+  const int in_height = conv_param->input_h_;
+  const int in_width = conv_param->input_w_;
+
+  const int output_w = conv_param->output_w_;
+
+  const int channels = conv_param->input_channel_;
+  const int stride = kernel_h * kernel_w;
+
+  int kernel_row, kernel_col;
+
+  for (int i = 0; i < real_cal_num; i++) {
+    int block_start = start + i;
+    int input_h = block_start / output_w * stride_h;
+    int input_w = block_start % output_w * stride_w;
+    float *data_col = data_col_orig + i * channels * stride;
+    for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+      int input_row = -pad_up + kernel_row * dilation_h + input_h;
+      for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+        int input_col = -pad_left + kernel_col * dilation_w + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
+          const int offset = (input_row * in_width + input_col) * channels;
+          for (int c = 0; c < channels; c++) {
+            data_col[c * stride] = in_data[offset + c];
+          }
+          data_col++;
+        } else {
+          for (int c = 0; c < channels; c++) {
+            data_col[c * stride] = 0;
+          }
+          data_col++;
+        }
+      }
+    }
+  }
+}
+
+void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num,
+                        int start) {
+  const int pad_left = conv_param->pad_l_;
+  const int pad_up = conv_param->pad_u_;
+
+  const int stride_h = conv_param->stride_h_;
+  const int stride_w = conv_param->stride_w_;
+
+  const int dilation_h = conv_param->dilation_h_;
+  const int dilation_w = conv_param->dilation_w_;
+
+  const int kernel_h = conv_param->kernel_h_;
+  const int kernel_w = conv_param->kernel_w_;
+
+  const int in_height = conv_param->input_h_;
+  const int in_width = conv_param->input_w_;
+
+  const int output_w = conv_param->output_w_;
+
+  const int channels = conv_param->input_channel_ / conv_param->group_;
+  const int tot_channels = conv_param->input_channel_;
+
+  int kernel_row, kernel_col;
+
+  if (channels == 1) {
+    for (int i = 0; i < real_cal_num; i++) {
+      int block_start = start + i;
+      int input_h = block_start / output_w * stride_h;
+      int input_w = block_start % output_w * stride_w;
+      for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+        int input_row = -pad_up + kernel_row * dilation_h + input_h;
+        for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+          int input_col = -pad_left + kernel_col * dilation_w + input_w;
+          if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
+            const int offset = (input_row * in_width + input_col) * tot_channels;
+            *data_col = in_data[offset];
+            data_col++;
+          } else {
+            *data_col = 0;
+            data_col++;
+          }
+        }
+      }
+    }
+  } else {
+    for (int i = 0; i < real_cal_num; i++) {
+      int block_start = start + i;
+      int input_h = block_start / output_w * stride_h;
+      int input_w = block_start % output_w * stride_w;
+      for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+        int input_row = -pad_up +
kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(data_col, in_data + offset, sizeof(float) * channels); + data_col += channels; + } else { + memset(data_col, 0, sizeof(float) * channels); + data_col += channels; + } + } + } + } + } +} + +void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index) { + rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); +} + +void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->output_h_; + const int in_width = conv_param->output_w_; + + const int output_w = conv_param->input_w_; + + const int tot_channels = conv_param->output_channel_; + const int channels = tot_channels / conv_param->group_; + int channel, kernel_row, kernel_col, output_rows, output_col; + for (channel = 0; channel < channels; channel++) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + for (output_rows = start; output_rows < start + rows; output_rows++) { + int input_row = -pad_up + kernel_row * dilation_h + output_rows * stride_h; + if (!((unsigned)(input_row) < (unsigned)(in_height))) { + for (output_col = output_w; output_col; output_col--) { + *(data_row++) = 0; + } + } else { + int input_col = -pad_left + kernel_col * dilation_w; + for (output_col = output_w; output_col; output_col--) { + if (((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels + channel; + *(data_row++) = in_data[offset]; + } else { + *(data_row++) = 0; + } + input_col += stride_w; + } + } + } + } + } + } +} + +void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_h = conv_param->output_h_; + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col, output_rows, output_col; + + int row_stride_offset = 0; + + for (output_rows = output_h; output_rows; output_rows--) { + int col_stride_offset = 0; + for (output_col = output_w; output_col; output_col--) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = 
-pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = data_im + offset; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + col_stride_offset += stride_w; + } + row_stride_offset += stride_h; + } +} + +void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = data_im + offset; + *data_im_ptr += *data_col; + } + data_col++; + } + } + } + } else { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = &data_im[offset]; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.h new file mode 100644 index 00000000..a29ae204 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pack_ext.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_PACK_EXT_H_
+#define NNACL_FP32_GRAD_PACK_EXT_H_
+
+#include <stdlib.h>
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
+                               int real_cal_num, int block_index);
+void RollingIm2ColPackDwUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
+                                 int real_cal_num, int block_index);
+
+void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start);
+void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start);
+void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_PACK_EXT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.c
new file mode 100644
index 00000000..2ba7604a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.c
@@ -0,0 +1,190 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32_grad/pooling_grad.h"
+#include <stddef.h>
+#include <math.h>
+#include <float.h>
+#include "nnacl_c/op_base.h"
+
+void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, const PoolingParameter *pooling_param,
+                    const PoolingComputeParam *pooling_args) {
+  int stride_w = pooling_param->stride_w_;
+  int stride_h = pooling_param->stride_h_;
+  int pad_w = pooling_param->pad_l_;
+  int pad_h = pooling_param->pad_u_;
+  int win_w = pooling_args->window_w_;
+  int win_h = pooling_args->window_h_;
+  int channel = pooling_args->input_channel_;
+  int in_w = pooling_args->input_w_;
+  int in_h = pooling_args->input_h_;
+  int output_w = pooling_args->output_w_;
+  int output_h = pooling_args->output_h_;
+
+  const float kk = 1.0f / (float)(win_h * win_w);
+#if ENABLE_ARM
+  const float32x4_t factor = vdupq_n_f32(kk);
+#endif
+  for (int ib = 0; ib < count; ib++) {
+    float *out = output_ptr + ib * in_h * in_w * channel;
+    const float *inPtr = input_ptr + ib * output_h * output_w * channel;
+    // iterate over yt
+    for (int yh = 0; yh < output_h; yh++) {
+      int over_h = pad_h - yh * stride_h;
+      int kh_s = MSMAX(0, over_h);
+      int kh_e = MSMIN(win_h, in_h + over_h);
+      for (int yw = 0; yw < output_w; yw++) {
+        int over_w = pad_w - yw * stride_w;
+        int kw_s = MSMAX(0, over_w);
+        int kw_e = MSMIN(win_w, in_w + over_w);
+        int ic = 0;
+        for (; ic < channel - C4NUM; ic += C4NUM) {
+          int idx = (yw + yh * output_w) * channel + ic;
+#ifdef ENABLE_ARM
+          float32x4_t in = vld1q_f32(inPtr + idx);
+          float32x4_t delta = vmulq_f32(in, factor);
+#else
+          float delta[C4NUM] = {inPtr[idx], inPtr[idx + C1NUM], inPtr[idx + C2NUM], inPtr[idx + C3NUM]};
+          for (int i = 0; i < C4NUM; i++) delta[i] *= kk;
+#endif
+          for (int kh = kh_s; kh < kh_e; kh++) {
+            int xh = yh * stride_h + kh - pad_h;
+            for (int kw = kw_s; kw < kw_e; kw++) {
+              int xw = yw * stride_w + kw - pad_w;
+#ifdef ENABLE_ARM
+              float *out_vec = out + (xw + in_w * xh) * channel + ic;
+              float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic);
+              float32x4_t outs = vaddq_f32(outr, delta);
+              vst1q_f32(out_vec, outs);
+#else
+
+              for (int i = 0; i < C4NUM; i++) {
+                out[(xw + in_w * xh) * channel + ic + i] += ((float *)&delta)[i];
+              }
+#endif
+            }
+          }
+        }
+        for (; ic < channel; ic++) {
+          int idx = (yw + yh * output_w) * channel + ic;
+          float delta = inPtr[idx] * kk;
+          for (int kh = kh_s; kh < kh_e; kh++) {
+            int xh = yh * stride_h + kh - pad_h;
+            for (int kw = kw_s; kw < kw_e; kw++) {
+              int xw = yw * stride_w + kw - pad_w;
+              out[(xw + in_w * xh) * channel + ic] += delta;
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+#ifdef ENABLE_ARM
+static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) {
+  uint32x4_t res = vcgtq_f32(in, *max);
+  int32x4_t m_index = vbslq_s32(res, index, prev_index);
+  *max = vbslq_f32(res, in, *max);
+  return m_index;
+}
+#endif
+
+void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
+                    const PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args) {
+  int stride_w = pooling_param->stride_w_;
+  int stride_h = pooling_param->stride_h_;
+  int pad_w = pooling_param->pad_l_;
+  int pad_h = pooling_param->pad_u_;
+  int win_w = pooling_args->window_w_;
+  int win_h = pooling_args->window_h_;
+  int channel = pooling_args->input_channel_;
+  int in_w = pooling_args->input_w_;
+  int in_h = pooling_args->input_h_;
+  int output_w = pooling_args->output_w_;
+  int output_h = pooling_args->output_h_;
+  for (int ib = 0; ib < output_batch; ib++) {
+    float *out =
output_ptr + ib * in_h * in_w * channel; + const float *inPtr = input_ptr + ib * in_h * in_w * channel; + const float *dyPtr = dy_ptr + ib * output_h * output_w * channel; + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic <= channel - C4NUM; ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_ARM + uint32x4_t max_idx = vdupq_n_u32(0); + float32x4_t max_val = vdupq_n_f32(-FLT_MAX); + float32x4_t delta = vld1q_f32(dyPtr + idx); +#else + float delta[C4NUM] = {dyPtr[idx], dyPtr[idx + C1NUM], dyPtr[idx + C2NUM], dyPtr[idx + C3NUM]}; + float max_val[C4NUM] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX}; + int max_idx[C4NUM] = {0}; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; +#ifdef ENABLE_ARM + uint32x4_t index = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3}; + float32x4_t in = vld1q_f32(inPtr + val_idx); + max_idx = vreinterpretq_u32_s32( + MaxIndex(in, &max_val, vreinterpretq_s32_u32(index), vreinterpretq_s32_u32(max_idx))); +#else + float val[C4NUM] = {inPtr[val_idx], inPtr[val_idx + C1NUM], inPtr[val_idx + C2NUM], + inPtr[val_idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) { + if (val[i] > max_val[i]) { + max_val[i] = val[i]; + max_idx[i] = val_idx + i; + } + } +#endif + } + } + for (int i = 0; i < C4NUM; i++) { + out[((int *)&max_idx)[i]] += ((float *)&delta)[i]; + } + } + for (; ic < channel; ic++) { + float max_val = -FLT_MAX; + int max_idx = 0; + int idx = (yw + yh * output_w) * channel + ic; + float delta = dyPtr[idx]; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; + float val = inPtr[val_idx]; + if (val > max_val) { + max_val = val; + max_idx = val_idx; + } + } + } + out[max_idx] += delta; + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.h new file mode 100644 index 00000000..7f1e12eb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/pooling_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP32_GRAD_POOLING_GRAD_H_
+#define NNACL_FP32_GRAD_POOLING_GRAD_H_
+
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/kernel/pooling.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, const PoolingParameter *pooling_param,
+                    const PoolingComputeParam *pooling_args);
+void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
+                    const PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_FP32_GRAD_POOLING_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.c
new file mode 100644
index 00000000..11d670a3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.c
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string.h>
+#include "nnacl_c/fp32_grad/reduce_grad.h"
+#include "nnacl_c/fp32_grad/utils.h"
+#include "nnacl_c/op_base.h"
+
+void ReduceMeanByAxes(const float *input_data, int *input_iter, const int *input_dims, int input_num_dims,
+                      const int *axes, int num_axes, float *output_data, const int *output_dims, int output_num_dims) {
+  size_t num_outputs = 1;
+  for (int idx = 0; idx < output_num_dims; ++idx) {
+    size_t current = (size_t)(output_dims[idx]);
+    num_outputs *= current;
+  }
+
+  // Reset input iterator.
+  for (int idx = 0; idx < input_num_dims; ++idx) {
+    input_iter[idx] = 0;
+  }
+  // Iterate through input_data.
+  do {
+    size_t input_offset = GetInputOffset(input_num_dims, input_dims, input_iter);
+    size_t output_offset = GetOutputOffset(input_num_dims, input_dims, input_iter, num_axes, axes);
+    output_data[output_offset] += input_data[input_offset];
+  } while (NextIndex(input_num_dims, input_dims, input_iter));
+
+  // Calculate mean by dividing output_data by the number of aggregated elements.
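+  // For example, with input_dims = {2, 3, 4} and axes = {1}, each output element
+  // has accumulated input_dims[1] = 3 inputs in the loop above, so
+  // num_elements_in_axis below evaluates to 3 and each sum is divided by 3.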
+ size_t num_elements_in_axis = 1; + for (int idx = 0; idx < num_axes; ++idx) { + size_t current = (size_t)(input_dims[axes[idx]]); + num_elements_in_axis *= current; + } + + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = output_data[idx] / (float)(num_elements_in_axis); + } +} + +float ReduceMeanAll(const float *src, int size) { + float sum = 0; + for (int i = 0; i < size; ++i) { + sum += src[i]; + } + return sum / size; +} + +void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) { + int num_outputs = 1; + int same_shape = 1; + for (int idx = 0; idx < num_dims; ++idx) { + num_outputs *= output_dims[idx]; + if (output_dims[idx] != input_dims[idx]) same_shape = 0; + } + if (same_shape) { + memcpy(output, input, (size_t)(num_outputs) * sizeof(float)); + return; + } + + memset(output, 0, (size_t)(num_outputs) * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes[C5NUM] = {0}; + int num_axes = 0; + for (int i = 0; i < num_dims; i++) { + if (output_dims[i] == C1NUM && num_axes < C5NUM) { + axes[num_axes++] = i; + } + } + + // Iterate through input_data. + do { + size_t input_offset = GetInputOffset(num_dims, input_dims, input_iter); + size_t output_offset = GetOutputOffset(num_dims, input_dims, input_iter, num_axes, axes); + output[output_offset] += input[input_offset]; + } while (NextIndex(num_dims, input_dims, input_iter)); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.h new file mode 100644 index 00000000..edb61025 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/reduce_grad.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_REDUCE_GRAD_H_ +#define NNACL_FP32_GRAD_REDUCE_GRAD_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif +float ReduceMeanAll(const float *src, int size); +void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_REDUCE_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.c new file mode 100644 index 00000000..678abd47 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/fp32_grad/resize_grad.h"
+#include <math.h>
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/errorcode.h"
+
+int ResizeNearestNeighborGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format,
+                              const ResizeGradParameter *param) {
+  bool align_corners = param->align_corners_;
+  size_t in_hw_size = param->in_width_ * param->in_height_;
+  size_t out_hw_size = param->out_width_ * param->out_height_;
+
+  if (format == Format_NHWC) {
+    NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_);
+    for (int32_t b = 0; b < batch_size; ++b) {
+      for (size_t i = 0; i < in_hw_size; ++i) {
+        size_t in_y = i / param->in_width_;
+        size_t in_x = i % param->in_width_;
+        for (size_t c = 0; c < (size_t)channel; ++c) {
+          size_t out_y = MSMIN(
+            (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_),
+            param->out_height_ - 1);
+          size_t out_x = MSMIN(
+            (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_),
+            param->out_width_ - 1);
+          size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c;
+          size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c;
+          out_addr[out_offset] += in_addr[in_offset];
+        }
+      }
+      out_addr += out_hw_size * channel;
+      in_addr += in_hw_size * channel;
+    }
+  } else if (format == Format_NCHW) {
+    for (int32_t b = 0; b < batch_size; ++b) {
+      for (size_t c = 0; c < (size_t)channel; ++c) {
+        for (size_t h = 0; h < param->in_height_; ++h) {
+          size_t out_y =
+            MSMIN((align_corners) ? (size_t)roundf(h * param->height_scale_) : (size_t)floorf(h * param->height_scale_),
+                  param->out_height_ - 1);
+          for (size_t w = 0; w < param->in_width_; ++w) {
+            size_t out_x =
+              MSMIN((align_corners) ?
(size_t)roundf(w * param->width_scale_) : (size_t)floorf(w * param->width_scale_), + param->out_width_ - 1); + out_addr[out_y * param->out_width_ + out_x] += in_addr[h * param->in_width_ + w]; + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} + +int ResizeBiLinearGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param) { + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t h = i / param->in_width_; + size_t w = i % param->in_width_; + for (size_t c = 0; c < (size_t)channel; ++c) { + float in_y = (float)h * param->height_scale_; + size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); + size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); + float y_lerp = in_y - floorf(in_y); + const float inverse_y_lerp = 1.0 - y_lerp; + + float in_x = (float)w * param->width_scale_; + size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); + size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); + float x_lerp = in_x - floorf(in_x); + const float inverse_x_lerp = 1.0 - x_lerp; + + size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; + size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + size_t out_offset_bottom_y_left_x = + bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_bottom_y_right_x = + bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + + out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float)(inverse_y_lerp * inverse_x_lerp); + out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float)(inverse_y_lerp * x_lerp); + out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float)(y_lerp * inverse_x_lerp); + out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + size_t in_height = param->in_height_; + size_t in_width = param->in_width_; + size_t out_height = param->out_height_; + size_t out_width = param->out_width_; + out_hw_size = out_height * out_width; + in_hw_size = in_height * in_width; + + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < (size_t)channel; ++c) { + for (size_t h = 0; h < in_height; ++h) { + const float in_y = (float)(h)*param->height_scale_; + const size_t top_y_index = MSMAX((size_t)floorf(in_y), 0); + const size_t bottom_y_index = MSMIN((size_t)ceilf(in_y), out_height - 1); + const float y_lerp = in_y - floorf(in_y); + const float inverse_y_lerp = 1.0 - y_lerp; + for (size_t w = 0; w < in_width; ++w) { + const float in_x = (float)(w)*param->width_scale_; + const size_t left_x_index = MSMAX((size_t)floorf(in_x), 0); + const size_t right_x_index = MSMIN((size_t)ceilf(in_x), out_width - 1); + const float x_lerp = in_x - floorf(in_x); + const float inverse_x_lerp = 1.0 - x_lerp; + out_addr[top_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float)(inverse_y_lerp * inverse_x_lerp); + 
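// the same dy value is also scattered to the remaining three taps with
+            // the complementary bilinear weights (top-right, bottom-left, bottom-right):
+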
out_addr[top_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float)(inverse_y_lerp * x_lerp); + out_addr[bottom_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float)(y_lerp * inverse_x_lerp); + out_addr[bottom_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.h new file mode 100644 index 00000000..b0f65a60 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_RESIZE_GRAD_H_ +#define NNACL_FP32_GRAD_RESIZE_GRAD_H_ + +#include "nnacl_c/fp32_grad/resize_grad_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeNearestNeighborGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param); +int ResizeBiLinearGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad_parameter.h new file mode 100644 index 00000000..2d0b4629 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/resize_grad_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_
+#define NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct ResizeGradParameter {
+  OpParameter op_parameter_;
+  bool align_corners_;
+  int method;
+  size_t in_height_;
+  size_t in_width_;
+  size_t out_height_;
+  size_t out_width_;
+  float height_scale_;
+  float width_scale_;
+} ResizeGradParameter;
+
+#endif // NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/smooth_l1_loss.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/smooth_l1_loss.h
new file mode 100644
index 00000000..7c46b4b5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/smooth_l1_loss.h
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_
+#define NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct {
+  OpParameter op_parameter_;
+  float beta_;
+} SmoothL1LossParameter;
+
+#endif // NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.c
new file mode 100644
index 00000000..594ef6de
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.c
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h"
+#include <math.h>
+
+void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2,
+                        size_t number_of_classes, int batch_size) {
+  float eps = 1e-6;
+  if (grads != NULL) {
+    for (size_t i = 0; i < (size_t)(batch_size); ++i) {
+      float loss = 0.f;
+      for (size_t j = 0; j < number_of_classes; ++j) {
+        float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]);
+        grads[i * number_of_classes + j] = (logits[i * number_of_classes + j] - labels[i * number_of_classes + j]);
+        loss += labels[i * number_of_classes + j] * logit;
+      }
+      output2[i] = loss;
+    }
+  } else {
+    for (size_t i = 0; i < (size_t)(batch_size); ++i) {
+      float loss = 0.f;
+      for (size_t j = 0; j < number_of_classes; ++j) {
+        float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]);
+        loss += labels[i * number_of_classes + j] * logit;
+      }
+      output2[i] = loss;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h
new file mode 100644
index 00000000..7cd53bc4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
+#define NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
+
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2,
+                        size_t number_of_classes, int batch_size);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_crossentropy_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_crossentropy_parameter.h
new file mode 100644
index 00000000..57a03069
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_crossentropy_parameter.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_
+#define NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct SoftmaxCrossEntropyParameter {
+  // primitive parameter
+  OpParameter op_parameter_;
+  int n_dim_;
+
+  // shape correlative
+  int input_shape_[5];
+
+  // other parameter
+  int32_t batch_size_;
+  unsigned int number_of_classes_;
+  bool is_grad_;
+} SoftmaxCrossEntropyParameter;
+
+#endif // NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.c
new file mode 100644
index 00000000..bdc6e750
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.c
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/softmax_grad.h"
+#include <string.h>
+
+void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul,
+                 const int *input_shape, int n_dim, int ele_size, int32_t axis) {
+  int dim = 1;
+  int inner_size = 1, outter_size = 1;
+  for (int i = 0; i < axis; i++) {
+    outter_size *= input_shape[i];
+  }
+  for (int i = axis + 1; i < n_dim; i++) {
+    inner_size *= input_shape[i];
+  }
+  NNACL_CHECK_ZERO_RETURN(outter_size);
+  for (int i = 0; i < inner_size * input_shape[axis]; i++) sum_mul[i] = 1.0;
+  for (int i = 0; i < n_dim; i++) dim *= input_shape[i];
+  dim /= outter_size;
+  memcpy(output_ptr, yt_ptr, (size_t)(ele_size) * sizeof(float));
+
+  const int M = input_shape[axis];
+  const int N = inner_size;
+  for (int i = 0; i < outter_size; i++) {
+    int outter_offset = i * dim;
+    memset(sum_data, 0, (size_t)(inner_size) * sizeof(float));
+    for (int k = 0; k < inner_size; k++) {
+      int inner_offset = outter_offset + k;
+      for (int j = 0; j < input_shape[axis]; j++) {
+        int offset = inner_offset + j * inner_size;
+        sum_data[k] += output_ptr[offset] * input_ptr[offset];
+      }
+    }
+    for (int k = 0; k < M; ++k) {
+      float a = -sum_mul[k];
+      for (int j = 0; j < N; ++j) {
+        *(output_ptr + outter_offset + k * N + j) += a * sum_data[j];
+      }
+    }
+  }
+
+  for (int i = 0; i < ele_size; i++) {
+    output_ptr[i] *= input_ptr[i];
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.h
new file mode 100644
index 00000000..69f9b331
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
+#define NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
+
+#include "nnacl_c/fp32/softmax_fp32.h"
+#include "nnacl_c/fp32_grad/softmax_crossentropy_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul,
+                 const int *input_shape, int n_dim, int ele_size, int32_t axis);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.c
new file mode 100644
index 00000000..24c62ea3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.c
@@ -0,0 +1,102 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_grad/softmax_grad_utils.h"
+#include <math.h>
+#include <string.h>
+#include "nnacl_c/fp32/exp_fp32.h"
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+void ExpFp32Offset(const float *src, float *dst, float sub_bias, int num) {
+  int i = 0;
+#ifdef ENABLE_ARM64
+  int count = (num / C4NUM) * C4NUM;
+  for (; i < count; i += C4NUM) {
+    MS_FLOAT32X4 input = vld1q_f32(src + i);
+    MS_FLOAT32X4 bias = vdupq_n_f32(sub_bias);
+    MS_FLOAT32X4 i1 = vsubq_f32(input, bias);
+    simd_exp128(i1, dst + i);
+  }
+#endif
+  for (; i < num; ++i) {
+    simd_exp32(src[i] - sub_bias, dst + i);
+  }
+}
+
+// output = exp(input) / reduce_sum(exp(input), axis)
+static void SoftMaxP1Simple(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count,
+                            int length) {
+  for (int i = start; i < start + count; i++) {
+    int inner_offset = i * length;
+    float max_data = input_ptr[inner_offset];
+    for (int j = 0; j < length; j++) {
+      int axis_offset = inner_offset + j;
+      max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset];
+    }
+    ExpFp32Offset(input_ptr + inner_offset, output_ptr + inner_offset, max_data, length);
+    float _sum_data = 0;
+    for (int j = 0; j < length; j++) {
+      int axis_offset = inner_offset + j;
+      _sum_data += output_ptr[axis_offset];
+    }
+    sum_data[i] = _sum_data;
+  }
+}
+
+void SoftMaxP1(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count, int length,
+               int inner_size) {
+  if (inner_size == 1) {
+    SoftMaxP1Simple(input_ptr, output_ptr, sum_data, start, count, length);
+    return;
+  }
+  for (int i = start; i < start + count; i++) {
+    int outter_offset = i * length * inner_size;
+    int sum_outter_offset = i * inner_size;
+    for (int k = 0; k < inner_size; k++) {
+      int inner_offset = outter_offset + k;
+      float max_data = input_ptr[inner_offset];
+      for (int j = 0; j < length; j++) {
+        int axis_offset = inner_offset + j * inner_size;
+        max_data = max_data > input_ptr[axis_offset] ?
max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < length; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data); + } + float _sum_data = 0; + for (int j = 0; j < length; j++) { + int axis_offset = inner_offset + j * inner_size; + _sum_data += output_ptr[axis_offset]; + } + sum_data[k + sum_outter_offset] = _sum_data; + } + } +} + +void SoftMaxP2(const float *input_ptr, float *output_ptr, const float *sum_data, int start, int count, int length, + int inner_size) { + for (int i = start; i < start + count; i++) { + int outter_offset = i * length * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < length; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.h new file mode 100644 index 00000000..68aac00d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/softmax_grad_utils.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_SOFTMAX_H_ +#define NNACL_FP32_GRAD_SOFTMAX_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +void SoftMaxP1(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count, int length, + int inner_size); +void SoftMaxP2(const float *input_ptr, float *output_ptr, const float *sum_data, int start, int count, int length, + int inner_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_SOFTMAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.c new file mode 100644 index 00000000..4a9fe765 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.c @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/fp32_grad/strided_slice_grad.h" +#include "nnacl_c/errorcode.h" + +static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) { + size_t res = 1; + for (size_t j = 0; j < size; j++) { + res *= shape[((size_t)(i) + 1) + j]; + } + NNACL_CHECK_ZERO_RETURN_ERR(res); + NNACL_CHECK_ZERO_RETURN_ERR(shape[i]); + return (pos / res % shape[i]); +} + +int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, const StridedSliceParameter *param) { + if (inputs == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_8D) { + return NNACL_PARAM_INVALID; + } + + size_t size = 1; + const int *s = param->strides_; + const int *b = param->begins_; + for (int i = 0; i < DIMENSION_8D; i++) { + size *= (size_t)(param->in_shape_[i]); + } + + for (size_t pos = 0; pos < size; pos++) { + size_t i = CalcIndex(param->in_shape_, C7NUM, C0NUM, pos); + size_t j = CalcIndex(param->in_shape_, C6NUM, C1NUM, pos); + size_t k = CalcIndex(param->in_shape_, C5NUM, C2NUM, pos); + size_t l = CalcIndex(param->in_shape_, C4NUM, C3NUM, pos); + size_t m = CalcIndex(param->in_shape_, C3NUM, C4NUM, pos); + size_t n = CalcIndex(param->in_shape_, C2NUM, C5NUM, pos); + size_t o = CalcIndex(param->in_shape_, C1NUM, C6NUM, pos); + size_t p = CalcIndex(param->in_shape_, C0NUM, C7NUM, pos); + size_t input_idx = + (i * s[C0NUM] + b[C0NUM]) * dx_shape[C1NUM] * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * + dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] + + (j * s[C1NUM] + b[C1NUM]) * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * + dx_shape[C6NUM] * dx_shape[C7NUM] + + (k * s[C2NUM] + b[C2NUM]) * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] * + dx_shape[C7NUM] + + (l * s[C3NUM] + b[C3NUM]) * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] + + (m * s[C4NUM] + b[C4NUM]) * dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] + + (n * s[C5NUM] + b[C5NUM]) * dx_shape[C6NUM] * dx_shape[C7NUM] + (o * s[C6NUM] + b[C6NUM]) * dx_shape[C7NUM] + + (p * s[C7NUM] + b[C7NUM]); + output[input_idx] = inputs[pos]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.h new file mode 100644 index 00000000..e779e6de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/strided_slice_grad.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ +#define NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, const StridedSliceParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/utils.h new file mode 100644 index 00000000..4798715f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_grad/utils.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_UTILS_H_ +#define NNACL_FP32_GRAD_UTILS_H_ + +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +static inline size_t GetInputOffset(int num_dims, const int *dims, const int *iter) { + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); + } + + return offset; +} + +static inline size_t GetOutputOffset(int num_dims, const int *dims, const int *iter, int num_axis, const int *axes) { + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + // if we need to skip this axis + int is_axis = 0; + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { + if (idx == axes[axis_idx]) { + is_axis = 1; + break; + } + } + + if (is_axis == 0) { + offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); + } + } + return offset; +} + +static inline int NextIndex(int num_dims, const int *dims, int *current) { + int carry = 1; + for (int idx = num_dims - 1; idx >= 0; --idx) { + int current_val = current[idx] + carry; + if (dims[idx] == current_val) { + current[idx] = 0; + } else { + current[idx] = current_val; + carry = 0; + break; + } + } + return (carry == 0); +} + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.c new file mode 100644 index 00000000..53c2fe92 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h"
+#ifdef ENABLE_ARM64
+#include <arm_neon.h>
+#endif
+
+void MatMulSparse8x8(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c,
+                     const float *bias, ActType act_type, int out_stride) {
+#ifndef ENABLE_ARM64
+  return;
+#else
+  // mul-acc
+  for (int oc = 0; oc < 8; oc++) {
+    uint32_t cur_nnz = nnz[oc];
+    // init 8x1 C with bias
+    float32x4_t vacc1 = vld1q_dup_f32(bias + oc);
+    float32x4_t vacc2 = vacc1;
+    for (uint32_t nz = 0; nz < cur_nnz; nz++) {
+      // load w
+      float32x4_t vw = vld1q_dup_f32(b++);
+      // load 8 inputs
+      const float *input = a + (*(dmap++) / sizeof(float));
+      float32x4_t vi1 = vld1q_f32(input);
+      float32x4_t vi2 = vld1q_f32(input + 4);
+      vacc1 = vfmaq_f32(vacc1, vi1, vw);
+      vacc2 = vfmaq_f32(vacc2, vi2, vw);
+    }
+    // save output
+    *(c + oc) = vacc1[0];
+    *(c + 1 * out_stride + oc) = vacc1[1];
+    *(c + 2 * out_stride + oc) = vacc1[2];
+    *(c + 3 * out_stride + oc) = vacc1[3];
+    *(c + 4 * out_stride + oc) = vacc2[0];
+    *(c + 5 * out_stride + oc) = vacc2[1];
+    *(c + 6 * out_stride + oc) = vacc2[2];
+    *(c + 7 * out_stride + oc) = vacc2[3];
+  }
+#endif
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h
new file mode 100644
index 00000000..d95d235c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/fp32_sparse/matmul_sparse_x1_fp32.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_FP32_MATMUL_SPARSE_X1_H_
+#define NNACL_FP32_MATMUL_SPARSE_X1_H_
+
+#include <stdint.h>
+#include <stddef.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef ENABLE_ARM64
+void SPMM8x8Fp32(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, const float *bias,
+                 ActType act_type, size_t out_stride);
+#endif
+
+void MatMulSparse8x8(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c,
+                     const float *bias, ActType act_type, int out_stride);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // NNACL_FP32_MATMUL_SPARSE_X1_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_nd_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_nd_parameter.h
new file mode 100644
index 00000000..b52fa90d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_nd_parameter.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_ND_PARAMETER_H_ +#define NNACL_GATHER_ND_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct { + OpParameter op_parameter_; +} GatherNdParameter; + +#endif // NNACL_GATHER_ND_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_parameter.h new file mode 100644 index 00000000..8b9d729c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gather_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_PARAMETER_H_ +#define NNACL_GATHER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GatherParameter { + // Primitive parameter + OpParameter op_parameter_; + int axis_; +} GatherParameter; + +#endif // NNACL_GATHER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gelu_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gelu_parameter.h new file mode 100644 index 00000000..fa1e4ca8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gelu_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GELU_PARAMETER_H_ +#define NNACL_GELU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GeLUParameter { + // Primitive parameter + OpParameter op_parameter_; + bool approximate_; +} GeLUParameter; + +#endif // NNACL_GELU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/glu_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/glu_parameter.h new file mode 100644 index 00000000..b01912a9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/glu_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GLU_PARAMETER_H_ +#define NNACL_GLU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GluParameter { + OpParameter op_parameter_; + int axis_; +} GluParameter; + +#endif // NNACL_GLU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/grid_sampler_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/grid_sampler_parameter.h new file mode 100644 index 00000000..422162a1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/grid_sampler_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GRID_SAMPLER_PARAMETER_H_ +#define NNACL_GRID_SAMPLER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GridSamplerParameter { + OpParameter op_parameter_; + int64_t interpolation_mode_; + int64_t padding_mode_; + bool align_corners_; +} GridSamplerParameter; + +#endif // NNACL_GRID_SAMPLER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/group_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/group_norm_parameter.h new file mode 100644 index 00000000..a733d1b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/group_norm_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_GROUP_NORM_PARAMETER_H_ +#define NNACL_GROUP_NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" +typedef struct GroupNormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + int num_groups_; + int channel_; + int unit_; + int batch_; + bool affine_; + void *mean_; + void *variance_; +} GroupNormParameter; + +typedef struct GroupNormQuantArg { + int32_t in_zp_; + int32_t out_zp_; + double in_scale_; + double out_scale_; +} GroupNormQuantArg; + +#endif // NNACL_GROUP_NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gru_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gru_parameter.h new file mode 100644 index 00000000..36c84e2f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/gru_parameter.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GRU_PARAMETER_H_ +#define NNACL_GRU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct GruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; +} GruParameter; + +#endif // NNACL_GRU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.c new file mode 100644 index 00000000..4ee74d68 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/activation_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ActivationGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *input = inputs[0]; + const TensorC *input_grad = inputs[1]; + if (input->shape_size_ != input_grad->shape_size_) { + return NNACL_ERR; + } + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != input_grad->shape_[i]) { + return NNACL_ERR; + } + } + + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +REG_INFER(ActivationGrad, PrimType_ActivationGrad, ActivationGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.h new file mode 100644 index 00000000..cb128bd5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/activation_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H +#define MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ActivationGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.c new file mode 100644 index 00000000..ebc7ba14 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.c @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/infer/adam_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int AdamInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 10); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[2]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[9]) || NNACLGetElementNum(inputs[3]) != 1 || + NNACLGetElementNum(inputs[4]) != 1 || NNACLGetElementNum(inputs[5]) != 1 || NNACLGetElementNum(inputs[6]) != 1 || + NNACLGetElementNum(inputs[7]) != 1 || NNACLGetElementNum(inputs[8]) != 1) { + return NNACL_ERR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(Adam, PrimType_Adam, AdamInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.h new file mode 100644 index 00000000..b251f1ca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADAM_INFER_H +#define MINDSPORE_NNACL_ADAM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AdamInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADAM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.c new file mode 100644 index 00000000..fbb8a918 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/adam_weight_decay_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + const size_t expected_inputs_size = 10; + const int var_idx = 0; + const int m_idx = 1; + const int v_idx = 2; + const int lr_idx = 3; + const int beta1_idx = 4; + const int beta2_idx = 5; + const int epsilon = 6; + const int decay_idx = 7; + const int grad_idx = 8; + int check_ret = + CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, expected_inputs_size); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[m_idx]) || + NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[v_idx]) || + NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[grad_idx]) || + NNACLGetElementNum(inputs[lr_idx]) != 1 || NNACLGetElementNum(inputs[beta1_idx]) != 1 || + NNACLGetElementNum(inputs[beta2_idx]) != 1 || NNACLGetElementNum(inputs[epsilon]) != 1 || + NNACLGetElementNum(inputs[decay_idx]) != 1) { + return NNACL_ERR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + return NNACL_OK; +} + +REG_INFER(AdamWeightDecay, PrimType_AdamWeightDecay, AdamWeightDecayInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.h new file mode 100644 index 00000000..13d5c636 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/adam_weight_decay_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H +#define MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.c new file mode 100644 index 00000000..b84416ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/add_sub_grad_infer.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/infer/infer_register.h" + +int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *dy = inputs[0]; + const TensorC *x1 = inputs[1]; + const TensorC *x2 = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + param->ndim_ = dy->shape_size_; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + size_t fillDimNum0 = dy->shape_size_ - x1->shape_size_; + size_t fillDimNum1 = dy->shape_size_ - x2->shape_size_; + size_t j0 = 0; + size_t j1 = 0; + for (size_t i = 0; i < dy->shape_size_; i++) { + param->in_shape0_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + SetDataTypeFormat(dx1, dy); + SetDataTypeFormat(dx2, dy); + return NNACL_OK; +} + +REG_INFER(AddGrad, PrimType_AddGrad, AddSubGradInferShape) +REG_INFER(SubGrad, PrimType_SubGrad, AddSubGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.h new file mode 100644 index 00000000..216edefa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/add_sub_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H
+#define MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                         OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.c
new file mode 100644
index 00000000..c2f2a8fd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.c
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/addn_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  if (inputs_size < 2) {
+    return NNACL_ERR;
+  }
+  SetDataTypeFormat(output, input);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  size_t max_dims = input->shape_size_;
+  size_t max_dims_idx = 0;
+
+  // check zero dimension
+  for (size_t i = 0; i < max_dims; i++) {
+    NNACL_CHECK_FALSE(input->shape_[i] == 0, NNACL_ERR);
+  }
+
+  // determine max_dims
+  for (size_t i = 1; i < inputs_size; ++i) {
+    if (inputs[i]->shape_size_ > max_dims) {
+      max_dims = inputs[i]->shape_size_;
+      max_dims_idx = i;
+    }
+  }
+  ShapeSet(output->shape_, &output->shape_size_, inputs[max_dims_idx]->shape_, inputs[max_dims_idx]->shape_size_);
+
+  // make sure all inputs have the same size, or 1 (broadcasting), in every dimension
+  for (size_t i = 1; i < inputs_size; ++i) {
+    if ((inputs[i]->shape_size_ != max_dims) &&
+        (NNACLGetElementNum(inputs[i]) != NNACLGetElementNum(inputs[max_dims_idx]))) {
+      return NNACL_ERR;
+    }
+    if (inputs[i]->shape_size_ == max_dims) {
+      for (size_t j = 0; j < max_dims; j++) {
+        if (inputs[i]->shape_[j] != inputs[max_dims_idx]->shape_[j] && inputs[i]->shape_[j] != 1 &&
+            inputs[max_dims_idx]->shape_[j] != 1) {
+          return NNACL_ERR;
+        }
+      }
+    }
+  }
+
+  for (size_t d = 0; d < inputs[max_dims_idx]->shape_size_; ++d) {
+    size_t max_dim = 0;
+    for (size_t i = 0; i < inputs_size; ++i) {
+      size_t shift = max_dims - (size_t)(inputs[i]->shape_size_);
+      // align lower-rank inputs to the right (broadcast-pad leading dimensions with 1s)
+      size_t dim = (d < shift) ? 1 : (size_t)(inputs[i]->shape_[d - shift]);
+      if (dim > max_dim) {
+        max_dim = dim;
+      }
+    }
+    output->shape_[d] = (int)(max_dim);  // set the biggest dimension in the output tensor
+  }
+
+  return NNACL_OK;
+}
+
+REG_INFER(AddN, PrimType_AddN, AddnInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.h
new file mode 100644
index 00000000..3655d9c7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/addn_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_ADDN_INFER_H
+#define MINDSPORE_NNACL_ADDN_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ADDN_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.c
new file mode 100644
index 00000000..b52f7627
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.c
@@ -0,0 +1,122 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/infer/affine_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int MatmulInfer(const AffineParameter *param, int a_shape[MAX_SHAPE_SIZE], size_t a_shape_size, + int b_shape[MAX_SHAPE_SIZE], size_t b_shape_size) { + MatMulParameter *matmul_param = param->matmul_parameter_; + NNACL_CHECK_NULL_RETURN_ERR(matmul_param); + if (matmul_param->a_transpose_) { + if (a_shape_size < 2) { + return NNACL_ERR; + } + iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]); + } + if (matmul_param->b_transpose_) { + if (b_shape_size < 2) { + return NNACL_ERR; + } + iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); + } + return NNACL_OK; +} + +int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + // splice + matmul + TensorC *input0 = (TensorC *)inputs[0]; + TensorC *input1 = (TensorC *)inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + AffineParameter *param = (AffineParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + int a_shape[MAX_SHAPE_SIZE] = {0}; + size_t a_shape_size = 0; + ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_); + if (a_shape_size == 4 && a_shape[2] == 1 && a_shape[3] == 1) { + a_shape_size = 2; + SetShapeArray(input0, a_shape, a_shape_size); + } + int context_min = param->context_[0]; + int context_max = param->context_[param->context_size_ - 1]; + + a_shape[1] = input0->shape_[1] - (context_max - context_min); + a_shape[2] = param->output_dim_; + + int b_shape[MAX_SHAPE_SIZE] = {0}; + size_t b_shape_size = 0; + ShapeSet(b_shape, &b_shape_size, input1->shape_, input1->shape_size_); + + bool del_start = false; + bool del_end = false; + if (a_shape_size == 1) { + int ret = ShapeInsert(a_shape, &a_shape_size, 0, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(input0, a_shape, a_shape_size); + del_start = true; + } + if (b_shape_size == 1) { + ShapePush(b_shape, &b_shape_size, 1); + SetShapeArray(input1, b_shape, b_shape_size); + del_end = true; + } + for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) { + if (a_shape[a_shape_size - 3 - i] != b_shape[b_shape_size - 3 - i]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + + int ret = MatmulInfer(param, a_shape, a_shape_size, b_shape, b_shape_size); + if (ret != NNACL_OK) { + return ret; + } + + int c_shape[MAX_SHAPE_SIZE]; + size_t c_shape_size = 0; + ShapeSet(c_shape, &c_shape_size, a_shape, a_shape_size); + if (c_shape_size < 1 || b_shape_size < 1) { + return NNACL_ERR; + } + c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1]; + if (del_start) { + int erase_ret = ShapeErase(c_shape, &c_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + } + if (del_end) { + c_shape_size--; + } + SetShapeArray(output, c_shape, c_shape_size); + return NNACL_OK; +} + +REG_INFER(Affine, PrimType_Affine, AffineInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.h new file mode 100644 index 00000000..04987c21 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/affine_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei 
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ +#define MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/affine_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.c new file mode 100644 index 00000000..2f97643c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/infer/all_gather_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter) {
+  if (inputs_size != 1 || outputs_size != 1) {
+    return NNACL_NULL_PTR;
+  }
+  if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  SetDataTypeFormat(outputs[0], inputs[0]);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  AllGatherParameter *param = (AllGatherParameter *)parameter;
+  if (param->rank_size_ <= 0) {
+    return NNACL_INFER_INVALID;
+  }
+
+  const TensorC *input_tensor = inputs[0];
+  const int *in_shape = input_tensor->shape_;
+  TensorC *out_tensor = outputs[0];
+
+  int out_shape[MAX_SHAPE_SIZE];
+  size_t out_shape_size = 0;
+  out_shape[0] = in_shape[0] * param->rank_size_;
+  out_shape_size++;
+  for (size_t i = 1; i < input_tensor->shape_size_; i++) {
+    out_shape[i] = in_shape[i];
+    out_shape_size++;
+  }
+  SetShapeArray(out_tensor, out_shape, out_shape_size);
+  return NNACL_OK;
+}
+
+REG_INFER(AllGather, PrimType_AllGather, AllGatherInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.h
new file mode 100644
index 00000000..40527aa6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/all_gather_infer.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_
+#define MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/all_gather_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.c
new file mode 100644
index 00000000..15bb6604
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.c
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/infer/apply_momentum_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ApplyMomentumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 5); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[3]) || NNACLGetElementNum(inputs[2]) != 1 || + NNACLGetElementNum(inputs[4]) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + if (out == NULL) { + return NNACL_NULL_PTR; + } + out->data_type_ = inputs[0]->data_type_; + out->format_ = inputs[0]->format_; + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(ApplyMomentum, PrimType_ApplyMomentum, ApplyMomentumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.h new file mode 100644 index 00000000..57e3ae10 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/apply_momentum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H +#define MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ApplyMomentumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.c new file mode 100644 index 00000000..b36e7fa8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/argmin_max_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 1 || outputs_size > 2) { + return NNACL_ERR; + } + + ArgMinMaxParameter *param = (ArgMinMaxParameter *)parameter; + const TensorC *input = inputs[0]; + TensorC *output_1 = NULL; + TensorC *output_2 = NULL; + if (outputs_size == 2) { + output_1 = outputs[0]; + output_2 = outputs[1]; + } else if (param->out_value_) { + output_2 = outputs[0]; + } else { + output_1 = outputs[0]; + } + + if (output_1 != NULL) { + output_1->format_ = input->format_; + output_1->data_type_ = kNumberTypeInt32; + } + if (output_2 != NULL) { + SetDataTypeFormat(output_2, input); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int input_shape_size = (int)input->shape_size_; + int axis = param->axis_ < 0 ? param->axis_ + input_shape_size : param->axis_; + if (axis >= input_shape_size || axis < 0) { + return NNACL_PARAM_INVALID; + } + if (param->topk_ == 1 && !param->keep_dims_) { + int erase_ret = ShapeErase(output_shape, &output_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + } else { + output_shape[axis] = param->topk_; + } + + if (output_1 != NULL) { + SetShapeArray(output_1, output_shape, output_shape_size); + } + if (output_2 != NULL) { + SetShapeArray(output_2, output_shape, output_shape_size); + } + return NNACL_OK; +} + +REG_INFER(ArgMin, PrimType_ArgMinFusion, ArgMinMaxInferShape) +REG_INFER(ArgMax, PrimType_ArgMaxFusion, ArgMinMaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.h new file mode 100644 index 00000000..1afc71fc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/argmin_max_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_ARGMAX_INFER_H +#define MINDSPORE_NNACL_ARGMAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/arg_min_max_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ARGMAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.c new file mode 100644 index 00000000..7ad9d70c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.c @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/arithmetic_compare_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ArithmeticCompareInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int res = ArithmeticInferShape(inputs, inputs_size, outputs, outputs_size, parameter); + TensorC *output = outputs[0]; + if (output == NULL) { + return NNACL_NULL_PTR; + } + output->data_type_ = kNumberTypeBool; + return res; +} + +REG_INFER(Equal, PrimType_Equal, ArithmeticCompareInferShape) +REG_INFER(Greater, PrimType_Greater, ArithmeticCompareInferShape) +REG_INFER(GreaterEqual, PrimType_GreaterEqual, ArithmeticCompareInferShape) +REG_INFER(Less, PrimType_Less, ArithmeticCompareInferShape) +REG_INFER(LessEqual, PrimType_LessEqual, ArithmeticCompareInferShape) +REG_INFER(NotEqual, PrimType_NotEqual, ArithmeticCompareInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.h new file mode 100644 index 00000000..3fce52c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_compare_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H
+#define MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H
+
+#include "nnacl_c/infer/arithmetic_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ArithmeticCompareInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                size_t outputs_size, OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.c
new file mode 100644
index 00000000..ed5cdab6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.c
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/arithmetic_grad_infer.h"
+#include "nnacl_c/arithmetic_parameter.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+/*
+ * The arithmetic grad ops include AddGrad, SubGrad, MulGrad, DivGrad, MaximumGrad and MinimumGrad
+ * (see arithmetic_fp32.h). MaximumGrad and MinimumGrad run through MaximumGradInferShape,
+ * AddGrad and SubGrad run through AddSubGradInferShape, and the remaining ops run through this
+ * function.
+ */
+int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                             OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *dy = inputs[0];
+  const TensorC *x1 = inputs[1];
+  const TensorC *x2 = inputs[2];
+  TensorC *dx1 = outputs[0];
+  TensorC *dx2 = outputs[1];
+
+  if (dy->shape_size_ > MAX_SHAPE_SIZE || x1->shape_size_ > MAX_SHAPE_SIZE || x2->shape_size_ > MAX_SHAPE_SIZE) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  int in_shape0[MAX_SHAPE_SIZE] = {0};
+  size_t in_shape0_size = 0;
+  ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_);
+  int in_shape1[MAX_SHAPE_SIZE] = {0};
+  size_t in_shape1_size = 0;
+  ShapeSet(in_shape1, &in_shape1_size, x2->shape_, x2->shape_size_);
+  int out_shape[MAX_SHAPE_SIZE] = {0};
+  size_t out_shape_size = 0;
+  ShapeSet(out_shape, &out_shape_size, dy->shape_, dy->shape_size_);
+
+  ArithmeticParameter *param = (ArithmeticParameter *)parameter;
+
+  if (NNACLGetElementNum(dx1) < NNACLGetElementNum(dx2)) {
+    param->ndim_ = in_shape1_size;
+    param->in_elements_num0_ = (int)param->ndim_;
+    param->in_elements_num1_ = (int)param->ndim_;
+    param->out_elements_num_ = (int)param->ndim_;
+    size_t fill_dim_num = in_shape1_size - in_shape0_size;  // This will not work for batch!
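+    // Broadcast alignment: left-pad the lower-rank input's shape (x1, held in in_shape0) with 1s
+    // up to dy's rank, e.g. x1 [3] against x2/dy [2, 3] yields {1, 3}. Note the shapes are stored
+    // swapped in this branch: the padded x1 shape goes into in_shape1_ and x2's shape into in_shape0_.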
+ int j = 0; + for (unsigned int i = 0; i < in_shape1_size; i++) { + if (i < fill_dim_num) { + param->in_shape1_[i] = 1; + } else { + param->in_shape1_[i] = in_shape0[j++]; + } + param->in_shape0_[i] = in_shape1[i]; + param->out_shape_[i] = out_shape[i]; + } + } else if (NNACLGetElementNum(dx2) < NNACLGetElementNum(dx1)) { + param->ndim_ = in_shape0_size; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + param->broadcasting_ = true; + int j = 0; + size_t fill_dim_num = in_shape0_size - in_shape1_size; + for (unsigned int i = 0; i < in_shape0_size; i++) { + if (i < fill_dim_num) { + param->in_shape1_[i] = 1; + } else { + param->in_shape1_[i] = in_shape1[j++]; + } + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; + } + } else { + param->broadcasting_ = false; + for (unsigned int i = 0; i < in_shape0_size; i++) { + param->in_shape1_[i] = in_shape1[i]; + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; + } + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + dx1->data_type_ = dy->data_type_; + dx2->data_type_ = dy->data_type_; + return NNACL_OK; +} + +REG_INFER(DivGrad, PrimType_DivGrad, ArithmeticGradInferShape) +REG_INFER(MulGrad, PrimType_MulGrad, ArithmeticGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.h new file mode 100644 index 00000000..b7d6bb54 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.c new file mode 100644 index 00000000..247ec130 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.c @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/arithmetic_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/infer/broadcast_to_infer.h"
+
+void SetOutputDtypeFormat(const TensorC *input0, const TensorC *input1, TensorC *output) {
+  output->format_ = input0->format_;
+  output->data_type_ = input0->data_type_;
+  // e.g. input0's shape is [1] and input1's shape is [1, 15, 15, 1];
+  // currently only the input with the larger shape size is regarded as the format-defining input.
+  // Legacy problem: if input0 failed to infer earlier, its shape is [-1] while input1's shape may be
+  // [1, 2] and need broadcasting. In that case input1's format is used, which is wrong and needs to
+  // be fixed later.
+  if (input0->data_ != NULL || input0->shape_size_ < input1->shape_size_) {
+    output->format_ = input1->format_;
+  }
+  // when input0 is const, it is quantized before the quant transform op is inserted, so use input1's data type instead
+  if (((input0->data_ != NULL) && (input1->data_type_ != kTypeUnknown)) ||
+      ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32))) {
+    output->data_type_ = input1->data_type_;
+  }
+}
+
+int BroadCastInferShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0,
+                        const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1, int *out_shape,
+                        bool *has_broad_cast) {
+  if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) {
+    return NNACL_ERR;
+  }
+  MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1);
+  if (*ndim >= MAX_SHAPE_SIZE) {
+    return NNACL_INFER_INVALID;
+  }
+
+  return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast);
+}
+
+int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                         OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  ArithmeticParameter *param = (ArithmeticParameter *)parameter;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  param->broadcasting_ = false;
+
+  const TensorC *input0 = inputs[0];
+  const TensorC *input1 = inputs[1];
+  TensorC *output = outputs[0];
+
+  const int *input_shape0 = input0->shape_;
+  size_t input_shape0_size = input0->shape_size_;
+  const int *input_shape1 = input1->shape_;
+  size_t input_shape1_size = input1->shape_size_;
+  SetOutputDtypeFormat(input0, input1, output);
+
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  int in_shape0[MAX_SHAPE_SIZE] = {0};
+  int in_shape1[MAX_SHAPE_SIZE] = {0};
+  int output_shape[MAX_SHAPE_SIZE] = {0};
+  int ndim = (int)input_shape0_size;
+  bool has_broad_cast = false;
+  if (BroadCastInferShape(input_shape0_size, input_shape1_size, input_shape0, input_shape1, &ndim, in_shape0,
+                          in_shape1, output_shape, &has_broad_cast) != NNACL_OK) {
+    return NNACL_ERR;
+  }
+
+  SetShapeArray(output, output_shape, ndim);
+
+  param->broadcasting_ = has_broad_cast;
+  param->ndim_ = (size_t)ndim;
+  if (ndim > MAX_SHAPE_SIZE) {
+    return NNACL_ERR;
+  }
+  memcpy(param->in_shape0_, in_shape0, ndim * sizeof(int));
+  memcpy(param->in_shape1_, in_shape1, ndim * sizeof(int));
+  memcpy(param->out_shape_, output_shape, ndim * sizeof(int));
+
+  param->in_elements_num0_ = 1;
+  param->in_elements_num1_ = 1;
+  param->out_elements_num_ = 1;
+  for (int i = 0; i < ndim; i++) {
+    param->in_elements_num0_ *= param->in_shape0_[i];
+    param->in_elements_num1_ *= param->in_shape1_[i];
+    param->out_elements_num_ *= param->out_shape_[i];
+  }
+  return NNACL_OK;
+}
+
+REG_INFER(Add, PrimType_AddFusion, ArithmeticInferShape)
+REG_INFER(BiasAdd, PrimType_BiasAdd, ArithmeticInferShape)
+REG_INFER(Div, PrimType_DivFusion, ArithmeticInferShape)
+REG_INFER(Eltwise, PrimType_Eltwise, ArithmeticInferShape)
+REG_INFER(FloorDiv, PrimType_FloorDiv, ArithmeticInferShape)
+REG_INFER(FloorMod, PrimType_FloorMod, ArithmeticInferShape)
+REG_INFER(LogicalAnd, PrimType_LogicalAnd, ArithmeticInferShape)
+REG_INFER(LogicalOr, PrimType_LogicalOr, ArithmeticInferShape)
+REG_INFER(Maximum, PrimType_Maximum, ArithmeticInferShape)
+REG_INFER(Minimum, PrimType_Minimum, ArithmeticInferShape)
+REG_INFER(Mod, PrimType_Mod, ArithmeticInferShape)
+REG_INFER(Mul, PrimType_MulFusion, ArithmeticInferShape)
+REG_INFER(RealDiv, PrimType_RealDiv, ArithmeticInferShape)
+REG_INFER(Sub, PrimType_SubFusion, ArithmeticInferShape)
+REG_INFER(SquaredDifference, PrimType_SquaredDifference, ArithmeticInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.h
new file mode 100644
index 00000000..d7a53551
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/arithmetic_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_ARITHMETIC_INFER_H
+#define MINDSPORE_NNACL_ARITHMETIC_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/arithmetic_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                         OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ARITHMETIC_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.c
new file mode 100644
index 00000000..d4a59496
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.c
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/assert_op_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int AssertOpInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return NNACL_OK; +} + +REG_INFER(Assert, PrimType_Assert, AssertOpInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.h new file mode 100644 index 00000000..49836662 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assert_op_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSERT_OP_INFER_H +#define MINDSPORE_NNACL_ASSERT_OP_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssertOpInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSERT_OP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.c new file mode 100644 index 00000000..6de394ef --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/assign_add_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int AssignAddInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x = inputs[0]; + const TensorC *y = inputs[1]; + TensorC *out = outputs[0]; + if (x->data_type_ != y->data_type_) { + return NNACL_ERR; + } + SetDataTypeFormat(out, x); + SetShapeTensor(out, x); + return NNACL_OK; +} + +REG_INFER(AssignAdd, PrimType_AssignAdd, AssignAddInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.h new file mode 100644 index 00000000..ec5cda41 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_add_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSIGN_ADD_INFER_H +#define MINDSPORE_NNACL_ASSIGN_ADD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssignAddInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSIGN_ADD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.c new file mode 100644 index 00000000..350c55a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/assign_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int AssignInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1])) { + return NNACL_ERR; + } + + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + return NNACL_OK; +} + +REG_INFER(Assign, PrimType_Assign, AssignInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.h new file mode 100644 index 00000000..4d3a8364 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/assign_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSIGN_INFER_H +#define MINDSPORE_NNACL_ASSIGN_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssignInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSIGN_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.c new file mode 100644 index 00000000..3f0ee5af --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.c @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/attention_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/attention_parameter.h" + +int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 7, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + AttentionParameter *param = (AttentionParameter *)parameter; + const TensorC *q_input = inputs[FIRST_INPUT]; + const TensorC *k_input = inputs[SECOND_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, q_input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *q_weight = inputs[FOURTH_INPUT]; + if (q_input->shape_size_ != C2NUM && q_input->shape_size_ != C3NUM) { + return NNACL_ERR; + } + if (q_weight->shape_size_ != C2NUM) { + return NNACL_ERR; + } + int batch = (q_input->shape_size_ == C2NUM) ? 1 : q_input->shape_[0]; + int f_seq = (q_input->shape_size_ == C2NUM) ? q_input->shape_[0] : q_input->shape_[1]; + int t_seq_len = k_input->shape_[1]; + if (q_input->shape_size_ == C2NUM) { + output0->shape_[FIRST_INPUT] = batch * f_seq; + output0->shape_[SECOND_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C2NUM; + } else { + output0->shape_[FIRST_INPUT] = batch; + output0->shape_[SECOND_INPUT] = f_seq; + output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C3NUM; + } + if (outputs_size >= C3NUM) { + TensorC *output1 = outputs[SECOND_INPUT]; + SetDataTypeFormat(output1, q_input); + output1->shape_[FIRST_INPUT] = batch; + output1->shape_[SECOND_INPUT] = param->head_num_; + output1->shape_[THIRD_INPUT] = param->head_size_; + output1->shape_[FOURTH_INPUT] = t_seq_len; + output1->shape_size_ = C4NUM; + TensorC *output2 = outputs[THIRD_INPUT]; + SetDataTypeFormat(output2, q_input); + output2->shape_[FIRST_INPUT] = batch; + output2->shape_[SECOND_INPUT] = param->head_num_; + output2->shape_[THIRD_INPUT] = t_seq_len; + output2->shape_[FOURTH_INPUT] = param->head_size_; + output2->shape_size_ = C4NUM; + } + return NNACL_OK; +} + +REG_INFER(Attention, PrimType_Attention, AttentionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.h new file mode 100644 index 00000000..f483d03f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/attention_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_ATTENTION_INFER_H
+#define MINDSPORE_NNACL_ATTENTION_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_ATTENTION_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.c
new file mode 100644
index 00000000..a2ebeba9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.c
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/audio_spectrogram_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+// Returns ceil(log2(length)): a five-step binary search over bit positions finds the highest set
+// bit (the floor), then one is added unless length is zero or a power of two.
+unsigned Log2Ceil(unsigned length) {
+  if (length == 0) {
+    return 0;
+  }
+  const unsigned n = length;
+  int floor = 0;
+  for (int i = 4; i >= 0; --i) {
+    const unsigned shift = (1 << (unsigned)i);
+    unsigned tmp = length >> shift;
+    if (tmp != 0) {
+      length = tmp;
+      floor += shift;
+    }
+  }
+  // test the original value: n is a power of two iff it equals its lowest set bit
+  return n == (n & ~(n - 1)) ? floor : floor + 1;
+}
+
+// Rounds length up to the next power of two, used as the FFT size for the analysis window.
+unsigned GetFftLength(unsigned length) {
+  unsigned shift = Log2Ceil(length);
+  return 1 << shift;
+}
+
+int AudioSpectrogramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                               size_t outputs_size, OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  SetDataTypeFormat(output, input);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->shape_size_ != 2) {
+    return NNACL_ERR;
+  }
+  AudioSpectrogramParameter *param = (AudioSpectrogramParameter *)parameter;
+  if (param->window_size_ < 2) {
+    return NNACL_ERR;
+  }
+  if (param->stride_ < 1) {
+    return NNACL_ERR;
+  }
+  int output_shape[3];
+  output_shape[0] = input->shape_[1];
+  int sample_sub_window = input->shape_[0] - param->window_size_;
+  output_shape[1] = sample_sub_window < 0 ?
0 : 1 + sample_sub_window / param->stride_; + // compute fft length + int fft_length = (int)GetFftLength(param->window_size_); + output_shape[2] = fft_length / 2 + 1; + SetShapeArray(output, output_shape, 3); + return NNACL_OK; +} + +REG_INFER(AudioSpectrogram, PrimType_AudioSpectrogram, AudioSpectrogramInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.h new file mode 100644 index 00000000..e0b0d6e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/audio_spectrogram_infer.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H +#define MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct AudioSpectrogramParameter { + OpParameter op_parameter_; + int window_size_; + int stride_; +} AudioSpectrogramParameter; + +int AudioSpectrogramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.c new file mode 100644 index 00000000..87607774 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.c @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/infer/batch_to_space_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int SetOutputShapeFromParam(const TensorC *const *inputs, TensorC **outputs, const OpParameter *parameter) {
+  int input_shape[MAX_SHAPE_SIZE] = {0};
+  size_t input_shape_size = 0;
+  ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_);
+
+  if (input_shape_size != 4) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  const BatchToSpaceParameter *param = (const BatchToSpaceParameter *)parameter;
+  const int32_t *block_shape = param->block_shape_;
+  const int32_t *crops = param->crops_;
+  int mul_block_shape = 1;
+
+  for (size_t i = 0; i < 2; ++i) {
+    if (block_shape[i] <= 0) {
+      return NNACL_PARAM_INVALID;
+    }
+    if (input_shape[kNHWC_N] % block_shape[i]) {
+      return NNACL_ERR;
+    }
+    mul_block_shape *= block_shape[i];
+  }
+
+  if (input_shape[kNHWC_N] < mul_block_shape) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (size_t i = 0; i < 4; ++i) {
+    if (crops[i] < 0) {
+      return NNACL_PARAM_INVALID;
+    }
+  }
+  if (mul_block_shape == 0) {
+    return NNACL_ERR;
+  }
+  int output_shape[MAX_SHAPE_SIZE];
+  size_t output_shape_size = input_shape_size;
+  output_shape[kNHWC_N] = input_shape[kNHWC_N] / mul_block_shape;
+  output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_shape[0] - crops[0] - crops[1];
+  output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_shape[1] - crops[2] - crops[3];
+  output_shape[kNHWC_C] = input_shape[kNHWC_C];
+  SetShapeArray(outputs[0], output_shape, output_shape_size);
+  return NNACL_OK;
+}
+
+int SetOutputShapeFromInput(const TensorC *const *inputs, TensorC **outputs) {
+  int input_shape[MAX_SHAPE_SIZE] = {0};
+  size_t input_shape_size = 0;
+  ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_);
+  if (input_shape_size != 4) {
+    return NNACL_PARAM_INVALID;
+  }
+  int *block_shape = (int *)(inputs[1]->data_);
+  int *crops = (int *)(inputs[2]->data_);
+  if (NNACLGetElementNum(inputs[1]) != 2) {
+    return NNACL_PARAM_INVALID;
+  }
+  if (NNACLGetElementNum(inputs[2]) != 4) {
+    return NNACL_PARAM_INVALID;
+  }
+  int mul_block_shape = 1;
+
+  for (size_t i = 0; i < 2; ++i) {
+    if (block_shape[i] <= 0) {
+      return NNACL_PARAM_INVALID;
+    }
+    if (input_shape[kNHWC_N] % block_shape[i]) {
+      return NNACL_ERR;
+    }
+    mul_block_shape *= block_shape[i];
+  }
+
+  if (input_shape[kNHWC_N] < mul_block_shape) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (size_t i = 0; i < 4; ++i) {
+    if (crops[i] < 0) {
+      return NNACL_PARAM_INVALID;
+    }
+  }
+  if (mul_block_shape == 0) {
+    return NNACL_ERR;
+  }
+  int output_shape[MAX_SHAPE_SIZE];
+  size_t output_shape_size = input_shape_size;
+  output_shape[kNHWC_N] = input_shape[kNHWC_N] / mul_block_shape;
+  output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_shape[0] - crops[0] - crops[1];
+  output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_shape[1] - crops[2] - crops[3];
+  output_shape[kNHWC_C] = input_shape[kNHWC_C];
+  SetShapeArray(outputs[0], output_shape, output_shape_size);
+  return NNACL_OK;
+}
+
+int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                           OpParameter *parameter) {
+  int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  if (outputs_size != 1 || (inputs_size != 1 && inputs_size != 3)) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  const TensorC *input = inputs[0];
+  if (input->format_ != Format_NHWC) {
+    return NNACL_FORMAT_ERROR;
+  }
+
SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 1) { + ret = SetOutputShapeFromParam(inputs, outputs, parameter); + return ret; + } + if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + ret = SetOutputShapeFromInput(inputs, outputs); + return ret; +} + +REG_INFER(BatchToSpace, PrimType_BatchToSpace, BatchToSpaceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.h new file mode 100644 index 00000000..fa073047 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/batch_to_space_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H +#define MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/batch_to_space_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.c new file mode 100644 index 00000000..a698b139 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
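A worked example of the shape arithmetic implemented twice above, as a minimal caller sketch. The values are hypothetical, and BatchToSpaceParameter's block_shape_/crops_ are assumed to be the inline int32 arrays that SetOutputShapeFromParam dereferences:

    #include "nnacl_c/infer/batch_to_space_infer.h"
    #include "nnacl_c/batch_to_space_parameter.h"

    int ExampleBatchToSpace(TensorC *out) {
      /* NHWC {4, 2, 2, 3}, block {2, 2}, zero crops:
       * N: 4 / (2 * 2) = 1, H: 2 * 2 - 0 - 0 = 4, W: 2 * 2 - 0 - 0 = 4, C: 3 */
      TensorC in = {0};
      in.format_ = Format_NHWC;
      in.data_type_ = kNumberTypeFloat32;
      int shape[4] = {4, 2, 2, 3};
      SetShapeArray(&in, shape, 4);
      BatchToSpaceParameter param = {0};  /* assumed layout: op_parameter_ first */
      param.block_shape_[0] = 2;
      param.block_shape_[1] = 2;
      const TensorC *inputs[1] = {&in};
      TensorC *outputs[1] = {out};
      return BatchToSpaceInferShape(inputs, 1, outputs, 1, (OpParameter *)&param);  /* out: {1, 4, 4, 3} */
    }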
+ */ + +#include "nnacl_c/infer/bias_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int BiasGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + if (in0->shape_size_ > MAX_SHAPE_SIZE || in0->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + int inshape[] = {in0->shape_[in0->shape_size_ - 1]}; + size_t inshape_size = 1; + SetDataTypeFormat(out, in0); + SetShapeArray(out, inshape, inshape_size); + + return NNACL_OK; +} + +REG_INFER(BiasAddGrad, PrimType_BiasAddGrad, BiasGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.h new file mode 100644 index 00000000..0ab82b54 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bias_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BIAS_GRAD_INFER_H +#define MINDSPORE_NNACL_BIAS_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BiasGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BIAS_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.c new file mode 100644 index 00000000..bb45ff28 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
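BiasAddGrad's contract above reduces the incoming gradient to the bias shape, i.e. the last input dimension. A sketch; the zeroed OpParameter suffices because the body never dereferences it beyond the shared null check:

    #include "nnacl_c/infer/bias_grad_infer.h"

    int ExampleBiasGrad(TensorC *db) {
      TensorC dy = {0};
      OpParameter op = {0};
      dy.data_type_ = kNumberTypeFloat32;
      dy.format_ = Format_NHWC;
      int shape[2] = {32, 16};  /* e.g. dy of a batch-32, 16-channel layer */
      SetShapeArray(&dy, shape, 2);
      const TensorC *inputs[1] = {&dy};
      TensorC *outputs[1] = {db};
      return BiasGradInferShape(inputs, 1, outputs, 1, &op);  /* db: {16} */
    }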
+ */ + +#include "nnacl_c/infer/binary_cross_entropy_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + const TensorC *x = inputs[0]; + TensorC *out = outputs[0]; + SetDataTypeFormat(out, x); + BinaryCrossEntropyParameter *param = (BinaryCrossEntropyParameter *)parameter; + ReductionType reduction = (ReductionType)(param->reduction); + if (reduction == Reduction_Mean || reduction == Reduction_Sum) { + out->shape_size_ = 1; + out->shape_[0] = 1; + } else { + SetShapeTensor(out, x); + } + return NNACL_OK; +} + +REG_INFER(BinaryCrossEntropy, PrimType_BinaryCrossEntropy, BinaryCrossEntropyInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.h new file mode 100644 index 00000000..18e66918 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/binary_cross_entropy_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H +#define MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.c new file mode 100644 index 00000000..989f96dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
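The reduction branch above, tabulated for a hypothetical input x of shape {8, 4}:

    /* param->reduction                 output shape
     * Reduction_Mean / Reduction_Sum   {1}      (scalar loss)
     * anything else ("none")           {8, 4}   (mirrors x)  */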
+ */ + +#include "nnacl_c/infer/bn_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 6, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in = inputs[1]; + if ((inputs[0]->shape_size_ == 4 && inputs[0]->format_ != Format_NHWC) || + (in->shape_size_ == 4 && in->format_ != Format_NHWC)) { + return NNACL_FORMAT_ERROR; + } + const TensorC *scale = inputs[2]; + SetShapeTensor(outputs[0], in); + SetDataTypeFormat(outputs[0], in); + SetShapeTensor(outputs[1], scale); + SetDataTypeFormat(outputs[1], scale); + SetShapeTensor(outputs[2], scale); + SetDataTypeFormat(outputs[2], scale); + return NNACL_OK; +} + +REG_INFER(BatchNormGrad, PrimType_BatchNormGrad, BnGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.h new file mode 100644 index 00000000..fd607a43 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/bn_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BN_GRAD_INFER_H +#define MINDSPORE_NNACL_BN_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BN_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.c new file mode 100644 index 00000000..f08a3186 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.c @@ -0,0 +1,200 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
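Stated compactly, the BatchNormGrad contract above. Only inputs 1 and 2 are read for shapes; naming the remaining tensors is an assumption about the usual six-input layout:

    /* inputs:  {dy, x, scale, save_mean, save_inv_var, ...} -> 3 outputs
     * dx     <- shape/type of x      e.g. x: {2, 4, 4, 8} -> dx: {2, 4, 4, 8}
     * dscale <- shape/type of scale  e.g. scale: {8}      -> dscale: {8}
     * dbias  <- shape/type of scale                       -> dbias:  {8}  */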
+ */ + +#include "nnacl_c/infer/broadcast_to_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int GetShapeByType(const TensorC *shape_tensor, int shape_size, int *dst_shape) { + if (shape_tensor == NULL || dst_shape == NULL) { + return NNACL_ERR; + } + if (shape_size == 0) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_NULL_RETURN_ERR(shape_tensor->data_); + switch (shape_tensor->data_type_) { + case kNumberTypeInt8: { + int8_t *data = (int8_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeInt32: { + int32_t *data = (int32_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeInt64: { + int64_t *data = (int64_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = (int)data[i]; + } + } break; + case kNumberTypeFloat: { + float *data = (float *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeUInt32: { + uint32_t *data = (uint32_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = (int)data[i]; + } + } break; + default: { + return NNACL_ERR; + } + } + return NNACL_OK; +} + +void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1) { + if (input_shape0_size < input_shape1_size) { + *ndim = input_shape1_size; + int fill_dim_num = input_shape1_size - input_shape0_size; + int j = 0; + for (int i = 0; i < input_shape1_size; i++) { + if (i < fill_dim_num) { + in_shape0[i] = 1; + } else { + in_shape0[i] = input_shape0[j++]; + } + in_shape1[i] = input_shape1[i]; + } + } else if (input_shape0_size > input_shape1_size) { + *ndim = input_shape0_size; + int fill_dim_num = input_shape0_size - input_shape1_size; + int j = 0; + for (int i = 0; i < input_shape0_size; i++) { + if (i < fill_dim_num) { + in_shape1[i] = 1; + } else { + in_shape1[i] = input_shape1[j++]; + } + in_shape0[i] = input_shape0[i]; + } + } else { + for (int i = 0; i < input_shape0_size; i++) { + in_shape1[i] = input_shape1[i]; + in_shape0[i] = input_shape0[i]; + } + } +} + +int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape, + bool *has_broad_cast) { + for (int i = 0; i < ndim; i++) { + if (in_shape0[i] != in_shape1[i]) { + if (in_shape0[i] == 1) { + out_shape[i] = in_shape1[i]; + } else if (in_shape1[i] == 1) { + out_shape[i] = in_shape0[i]; + } else { + return NNACL_ERR; + } + *has_broad_cast = true; + } else { + out_shape[i] = in_shape0[i]; + } + } + return NNACL_OK; +} + +int BroadCastToShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *out_shape, bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter 
*parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size != 1 && inputs_size != 2) { + return NNACL_ERR; + } + if (outputs_size != 1) { + return NNACL_ERR; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int dst_shape[MAX_SHAPE_SIZE] = {0}; + int dst_shape_size; + const int *input_shape = input->shape_; + int input_shape_size = input->shape_size_; + int output_shape[MAX_SHAPE_SIZE] = {0}; + int ndim = input_shape_size; + bool has_broad_cast = false; + if (inputs_size == 1) { + BroadcastToParameter *param = (BroadcastToParameter *)parameter; + dst_shape_size = (int)param->shape_size_; + if (dst_shape_size > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + for (int i = 0; i < dst_shape_size; i++) { + dst_shape[i] = param->shape_[i]; + } + } else { + const TensorC *shape_tensor = inputs[1]; + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + dst_shape_size = NNACLGetElementNum(shape_tensor); + if (dst_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ret = GetShapeByType(shape_tensor, dst_shape_size, dst_shape); + if (ret != NNACL_OK) { + return ret; + } + for (int i = 0; i < dst_shape_size; ++i) { + if (dst_shape[i] == -1) { + dst_shape[i] = inputs[0]->shape_[i]; + } + } + } + + if (BroadCastToShape(input_shape_size, dst_shape_size, input_shape, dst_shape, &ndim, output_shape, + &has_broad_cast) != NNACL_OK) { + return NNACL_ERR; + } + + SetShapeArray(outputs[0], output_shape, (size_t)ndim); + return NNACL_OK; +} + +REG_INFER(BroadcastTo, PrimType_BroadcastTo, BroadcastToInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.h new file mode 100644 index 00000000..9d607f2d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/broadcast_to_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
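The two helpers above are exported by the header that follows, so the right-alignment step can be exercised on its own. A sketch with made-up shapes:

    #include "nnacl_c/infer/broadcast_to_infer.h"

    void ExampleBroadcast(void) {
      const int a[3] = {2, 1, 3};
      const int b[2] = {4, 3};
      int ndim = 0;
      int s0[MAX_SHAPE_SIZE] = {0}, s1[MAX_SHAPE_SIZE] = {0}, out[MAX_SHAPE_SIZE] = {0};
      bool widened = false;
      MakeUpInputShapes(3, 2, a, b, &ndim, s0, s1);       /* s0: {2,1,3}, s1: {1,4,3} */
      BroadCastOutputShape(s0, s1, ndim, out, &widened);  /* out: {2,4,3}, widened == true */
      /* A slot where neither side is 1 (e.g. 2 vs. 4) makes BroadCastOutputShape return NNACL_ERR. */
    }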
+ */
+#ifndef MINDSPORE_NNACL_BROADCAST_TO_INFER_H
+#define MINDSPORE_NNACL_BROADCAST_TO_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/base/broadcast_to.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                          OpParameter *parameter);
+void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0,
+                       const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1);
+int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape,
+                         bool *has_broad_cast);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_BROADCAST_TO_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.c
new file mode 100644
index 00000000..c8d715c9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.c
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/cast_gather_reduce_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/split_parameter.h"
+
+int CastGatherReduceFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                     size_t outputs_size, OpParameter *parameter) {
+  int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  const size_t kMinimumGradInputsNum = 3;
+  if (inputs_size < kMinimumGradInputsNum || outputs_size != 1) {
+    return NNACL_ERR;
+  }
+  const TensorC *input = inputs[0];
+  const TensorC *indices = inputs[1];
+  TensorC *output = outputs[0];
+  output->data_type_ = input->data_type_;
+  output->format_ = input->format_;
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  if (inputs[C2NUM]->data_ == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int axis = *((int *)inputs[C2NUM]->data_);
+  if (axis < 0) {
+    axis += input->shape_size_;
+  }
+  int indices_shape[MAX_SHAPE_SIZE];
+  size_t indices_shape_size = 0;
+  ShapeSet(indices_shape, &indices_shape_size, indices->shape_, indices->shape_size_);
+  size_t indices_rank = indices_shape_size;
+  int in_shape[MAX_SHAPE_SIZE] = {0};
+  size_t in_shape_size = 0;
+  ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_);
+  if ((int)(in_shape_size) < axis + 1) {
+    return NNACL_ERR;
+  }
+  int out_shape[MAX_SHAPE_SIZE] = {0};
+  size_t out_shape_size = 0;
+  ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size);
+  int erase_ret = ShapeErase(out_shape, &out_shape_size, axis);
+  if (erase_ret != NNACL_OK) {
+    return NNACL_ERR;
+  }
+  for (int i = (int)(indices_rank - 1); i >= 0; --i) {
+    ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]);
+    if (ret != NNACL_OK) {
+      return NNACL_ERR;
+    }
+  }
+  out_shape[1] = 1;
+  SetShapeArray(output, out_shape, out_shape_size);
+  return NNACL_OK;
+}
+
+REG_INFER(CastGatherReduceFusion, PrimType_Inner_CastGatherReduceFusion, CastGatherReduceFusionInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.h
new file mode 100644
index 00000000..da9c1126
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_gather_reduce_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H
+#define MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int CastGatherReduceFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                     size_t outputs_size, OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.c
new file mode 100644
index 00000000..035697ba
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.c
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
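A shape walk-through of the erase/insert composition above, with hypothetical sizes:

    /* input {6, 50, 256}, indices {2, 3}, axis tensor data = 1:
     *   ShapeErase at axis 1        -> {6, 256}
     *   ShapeInsert of indices dims -> {6, 2, 3, 256}
     *   fused reduce pins dim 1     -> {6, 1, 3, 256}   (out_shape[1] = 1) */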
+ */
+
+#include "nnacl_c/infer/cast_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter) {
+  int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+  if (inputs_size != 2) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  output->format_ = input->format_;
+  const TensorC *dst_type = inputs[1];
+  if (dst_type->data_ == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  output->data_type_ = *((int *)dst_type->data_);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 &&
+      input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 &&
+      input->data_type_ != kNumberTypeInt64 && input->data_type_ != kNumberTypeFloat32 &&
+      input->data_type_ != kNumberTypeFloat16) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+
+  SetShapeTensor(output, input);
+  return NNACL_OK;
+}
+
+REG_INFER(Cast, PrimType_Cast, CastInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.h
new file mode 100644
index 00000000..530516b5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cast_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CAST_INFER_H
+#define MINDSPORE_NNACL_CAST_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_CAST_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c
new file mode 100644
index 00000000..39ab31c2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c
@@ -0,0 +1,338 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
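Cast reads its destination type from the *data* of the second input rather than from the parameter. A caller sketch; the one-element int32 payload is the convention the code above assumes:

    #include "nnacl_c/infer/cast_infer.h"

    int ExampleCast(TensorC *y) {
      TensorC x = {0}, dst = {0};
      OpParameter op = {0};
      x.data_type_ = kNumberTypeFloat32;
      x.format_ = Format_NHWC;
      int shape[2] = {4, 8};
      SetShapeArray(&x, shape, 2);
      int to = kNumberTypeInt32;
      dst.data_ = &to;  /* destination TypeId travels as tensor data */
      const TensorC *inputs[2] = {&x, &dst};
      TensorC *outputs[1] = {y};
      return CastInferShape(inputs, 2, outputs, 1, &op);  /* y: {4, 8}, dtype int32 */
    }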
+ */
+
+#include "nnacl_c/infer/common_infer.h"
+#include <stdlib.h>
+#include <string.h>
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/tensorlist_c_utils.h"
+
+bool CheckShaleValid(TensorC **tensors, int tensors_size) {
+  for (int i = 0; i < tensors_size; i++) {
+    TensorC *t = tensors[i];
+    for (size_t j = 0; j < t->shape_size_; j++) {
+      if (t->shape_[j] == -1) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+bool CheckInferShapeDone(TensorC **in, int in_size, TensorC **out, int out_size) {
+  for (int i = 0; i < in_size; i++) {
+    TensorC *t = in[i];
+    for (size_t j = 0; j < t->shape_size_; j++) {
+      if (t->shape_[j] == -1) {
+        return false;
+      }
+    }
+  }
+  for (int i = 0; i < out_size; i++) {
+    TensorC *t = out[i];
+    for (size_t j = 0; j < t->shape_size_; j++) {
+      if (t->shape_[j] == -1) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) {
+  size_t i = 0;
+  for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) {
+    dst_shape[i] = src_shape[i];
+  }
+  *dst_shape_size = i;
+}
+
+bool Int64ShapeSet(int *dst_shape, size_t *dst_shape_size, const int64_t *src_shape, size_t src_shape_size) {
+  if (dst_shape_size == NULL || dst_shape == NULL || src_shape == NULL) {
+    return false;
+  }
+  size_t i = 0;
+  for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) {
+    if (MS_UNLIKELY(src_shape[i] > (int64_t)INT32_MAX || src_shape[i] < (int64_t)INT32_MIN)) {
+      return false;
+    }
+    dst_shape[i] = (int32_t)(src_shape[i]);
+  }
+  *dst_shape_size = i;
+  return true;
+}
+
+void ShapePush(int *shape, size_t *shape_size, int value) {
+  if (*shape_size >= MAX_SHAPE_SIZE) {
+    return;
+  }
+  shape[*shape_size] = value;
+  *shape_size = *shape_size + 1;
+}
+
+int GetInt32DataFromTensor(const TensorC *tensor, int *result, size_t *result_size) {
+  if (tensor->data_ == NULL || result == NULL || result_size == NULL) {
+    return NNACL_ERR;
+  }
+  if (tensor->shape_size_ > MAX_SHAPE_SIZE) {
+    return NNACL_ERR;
+  }
+  int ele_num = NNACLGetElementNum(tensor);
+  if (ele_num <= 0) {
+    return NNACL_ERR;
+  }
+  *result_size = (size_t)ele_num;
+  if (tensor->data_type_ == kNumberTypeInt || tensor->data_type_ == kNumberTypeInt32) {
+    int *data = (int *)(tensor->data_);
+    for (int i = 0; i < ele_num; i++) {
+      result[i] = data[i];
+    }
+  } else if (tensor->data_type_ == kNumberTypeInt64) {
+    int64_t *data = (int64_t *)(tensor->data_);
+    for (int i = 0; i < ele_num; i++) {
+      if (data[i] >= INT32_MAX) {
+        return NNACL_ERR;
+      }
+      result[i] = (int32_t)data[i];
+    }
+  } else {
+    return NNACL_UNSUPPORTED_DATA_TYPE;
+  }
+  return NNACL_OK;
+}
+
+int ShapeInsert(int *shape, size_t *shape_size, int index, int value) {
+  if (index < 0 || index > *shape_size) {
+    return NNACL_ERR;
+  }
+  if (*shape_size >= MAX_SHAPE_SIZE) {
+    return NNACL_ERR;
+  }
+  for (int i = *shape_size; i > index; i--) {
+    shape[i] = shape[i - 1];
+  }
+  shape[index] = value;
+  *shape_size = *shape_size + 1;
+  return NNACL_OK;
+}
+
+int ShapeErase(int *shape, size_t *shape_size, int index) {
+  if (index < 0 || index >= *shape_size) {
+    return NNACL_ERR;
+  }
+
+  for (int i = index; i < *shape_size - 1; i++) {
+    shape[i] = shape[i + 1];
+  }
+  *shape_size = *shape_size - 1;
+  return NNACL_OK;
+}
+
+bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size) {
+  if (shape0_size != shape1_size) {
+    return false;
+  }
+  for (size_t i = 0; i < shape0_size; i++) {
+    if (shape0[i] != shape1[i]) {
+      return
false; + } + } + return true; +} + +void iswap(int *a, int *b) { + int tmp = *a; + *a = *b; + *b = tmp; +} + +int imin(int a, int b) { return a > b ? b : a; } + +int imax(int a, int b) { return a < b ? b : a; } + +// input == output completely refer to +// 1. zeros_like +int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(inputs[0]->shape_size_ == inputs[1]->shape_size_, NNACL_ERR); + for (int i = 0; i < inputs[0]->shape_size_; i++) { + if (inputs[0]->shape_[i] != inputs[1]->shape_[i]) { + return NNACL_ERR; + } + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithTwoInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + if (inputs[0]->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeFloat32; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, 
input->shape_, input->shape_size_); + if (input_shape_size == 0) { + return NNACL_ERR; + } + input_shape_size--; + SetShapeArray(output, input_shape, input_shape_size); + return NNACL_OK; +} + +bool InferFlag(const TensorC *const *inputs, size_t inputs_size) { + if (inputs == NULL) { + return false; + } + for (size_t i = 0; i < inputs_size; i++) { + if (inputs[i] == NULL) { + return false; + } + if (inputs[i]->data_type_ == kObjectTypeTensorType) { + if (InferFlagTensorList((TensorC *)inputs[i]) == false) { + return false; + } + } else { + for (size_t j = 0; j < inputs[i]->shape_size_; ++j) { + if (inputs[i]->shape_[j] < 0) { + return false; + } + } + } + } + return true; +} + +REG_INFER(Abs, PrimType_Abs, CommonInferShape) +REG_INFER(AbsGrad, PrimType_AbsGrad, CommonGradInferShape) +REG_INFER(Activation, PrimType_Activation, CommonInferShape) +REG_INFER(BatchNorm, PrimType_BatchNorm, CommonInferShape) +REG_INFER(BinaryCrossEntropyGrad, PrimType_BinaryCrossEntropyGrad, CommonInferShape) +REG_INFER(Ceil, PrimType_Ceil, CommonInferShape) +REG_INFER(Clip, PrimType_Clip, CommonInferShape) +REG_INFER(Cos, PrimType_Cos, CommonInferShape) +REG_INFER(Depend, PrimType_Depend, CommonInferShape) +REG_INFER(Elu, PrimType_Elu, CommonInferShape) +REG_INFER(Erf, PrimType_Erf, CommonInferShape) +REG_INFER(Exp, PrimType_ExpFusion, CommonInferShape) +REG_INFER(FakeQuantWithMinMaxVars, PrimType_FakeQuantWithMinMaxVars, CommonInferShape) +REG_INFER(Floor, PrimType_Floor, CommonInferShapeWithOneInput) +REG_INFER(LeakyRelu, PrimType_LeakyRelu, CommonInferShape) +REG_INFER(Log, PrimType_Log, CommonInferShape) +REG_INFER(Log1p, PrimType_Log1p, CommonInferShape) +REG_INFER(LogGrad, PrimType_LogGrad, CommonGradInferShape) +REG_INFER(LogicalNot, PrimType_LogicalNot, CommonInferShape) +REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC) +REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape) +REG_INFER(Neg, PrimType_Neg, CommonInferShape) +REG_INFER(NegGrad, PrimType_NegGrad, CommonGradInferShape) +REG_INFER(OnesLike, PrimType_OnesLike, CommonInferShape) +REG_INFER(PowerGrad, PrimType_PowerGrad, CommonGradInferShape) +REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape) +REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape) +REG_INFER(ReverseSequence, PrimType_ReverseSequence, CommonInferShape) +REG_INFER(Reverse, PrimType_ReverseV2, CommonInferShape) +REG_INFER(Round, PrimType_Round, CommonInferShape) +REG_INFER(Rsqrt, PrimType_Rsqrt, CommonInferShape) +REG_INFER(Scale, PrimType_ScaleFusion, CommonInferShape) +REG_INFER(SigmoidCrossEntropyWithLogits, PrimType_SigmoidCrossEntropyWithLogits, CommonInferShape) +REG_INFER(SigmoidCrossEntropyWithLogitsGrad, PrimType_SigmoidCrossEntropyWithLogitsGrad, CommonInferShape) +REG_INFER(Sin, PrimType_Sin, CommonInferShape) +REG_INFER(SmoothL1Loss, PrimType_SmoothL1Loss, CommonInferShape) +REG_INFER(SmoothL1LossGrad, PrimType_SmoothL1LossGrad, CommonInferShape) +REG_INFER(Sqrt, PrimType_Sqrt, CommonInferShape) +REG_INFER(SqrtGrad, PrimType_SqrtGrad, CommonInferShape) +REG_INFER(Square, PrimType_Square, CommonInferShape) +REG_INFER(ZerosLike, PrimType_ZerosLike, CommonInferShape) +REG_INFER(ScatterElements, PrimType_ScatterElements, CommonInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h new file mode 100644 index 00000000..24ddfb3f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h @@ -0,0 +1,94 @@ +/** + * 
Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_NNACL_COMMON_H_
+#define MINDSPORE_NNACL_COMMON_H_
+
+#include <stddef.h>
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+
+bool CheckShaleValid(TensorC **tensors, int tensors_size);
+bool CheckInferShapeDone(TensorC **in, int in_size, TensorC **out, int out_size);
+
+#define EPSILON_VALUE 1e-6
+
+enum NNACLLshProjectionType {
+  LshProjectionType_UNKNOWN = 0,
+  LshProjectionType_SPARSE = 1,
+  LshProjectionType_DENSE = 2,
+  LshProjectionType_MIN = LshProjectionType_UNKNOWN,
+  LshProjectionType_MAX = LshProjectionType_DENSE
+};
+
+typedef struct VectorC {
+  int *data_;
+  size_t size_;
+  size_t max_size_;
+  size_t per_malloc_size_;
+} VectorC;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int CheckAugmentNull(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     const OpParameter *parameter);
+int CheckAugmentNullSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                         const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj);
+int CheckAugmentNullSizeInputTwo(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                 size_t outputs_size, const OpParameter *parameter, size_t inputs_size_obj_0,
+                                 size_t inputs_size_obj_1, size_t outputs_size_obj);
+int CheckAugmentNullInputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                              const OpParameter *parameter, size_t inputs_size_obj);
+int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                               const OpParameter *parameter, size_t outputs_size_obj);
+int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                            const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj);
+void SetDataTypeFormat(TensorC *dst, const TensorC *src);
+
+void SetShapeTensor(TensorC *dst, const TensorC *src);
+void SetShapeArray(TensorC *dst, const int *src, size_t src_size);
+void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size);
+bool Int64ShapeSet(int *dst_shape, size_t *dst_shape_size, const int64_t *src_shape, size_t src_shape_size);
+void ShapePush(int *shape, size_t *shape_size, int value);
+int GetInt32DataFromTensor(const TensorC *tensor, int *result, size_t *result_size);
+int ShapeInsert(int *shape, size_t *shape_size, int index, int value);
+int ShapeErase(int *shape, size_t *shape_size, int index);
+bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size);
+
+void iswap(int *a, int *b);
+
+int imin(int a, int b);
+int imax(int a, int b);
+
+int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     OpParameter *parameter);
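A minimal sketch of how an identity-shape registration (Abs, Cos, ZerosLike, ...) is driven through CommonInferShape, declared just above; the real runtime dispatches through the infer registry rather than calling it directly:

    #include "nnacl_c/infer/common_infer.h"

    int ExampleCommonInfer(TensorC *y) {
      TensorC x = {0};
      OpParameter op = {0};
      x.data_type_ = kNumberTypeFloat32;
      x.format_ = Format_NHWC;
      int shape[3] = {1, 16, 16};
      SetShapeArray(&x, shape, 3);
      const TensorC *inputs[1] = {&x};
      TensorC *outputs[1] = {y};
      return CommonInferShape(inputs, 1, outputs, 1, &op);  /* y mirrors x: {1, 16, 16} */
    }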
+int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                         OpParameter *parameter);
+int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                 size_t outputs_size, OpParameter *parameter);
+int CommonInferShapeWithTwoInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                 size_t outputs_size, OpParameter *parameter);
+int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                             OpParameter *parameter);
+int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                  const OpParameter *parameter);
+bool InferFlag(const TensorC *const *inputs, size_t inputs_size);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MINDSPORE_NNACL_COMMON_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.c
new file mode 100644
index 00000000..94f838a1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.c
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/concat_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int DataTypeJudge(const TensorC *input, const TensorC *output) {
+  if ((input->data_type_ != output->data_type_) &&
+      !((input->data_type_ == kNumberTypeFloat16 && output->data_type_ == kNumberTypeFloat32) ||
+        (input->data_type_ == kNumberTypeFloat32 && output->data_type_ == kNumberTypeFloat16))) {
+    return NNACL_PARAM_INVALID;
+  }
+  return NNACL_OK;
+}
+
+int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     OpParameter *parameter) {
+  int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+  if (inputs_size < 1) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+
+  const TensorC *input0 = inputs[0];
+  TensorC *output = outputs[0];
+  SetDataTypeFormat(output, input0);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  const int *input0_shape = inputs[0]->shape_;
+  size_t input0_shape_size = inputs[0]->shape_size_;
+
+  ConcatParameter *param = (ConcatParameter *)parameter;
+  int axis = param->axis_ < 0 ?
param->axis_ + (int)input0_shape_size : param->axis_; + if (axis < 0 || axis >= (int)input0_shape_size) { + return NNACL_ERR; + } + if (input0_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; + size_t input0_shape_without_axis_size = 0; + ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); + int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_axis_dim = input0_shape[axis]; + for (size_t i = 1; i < inputs_size; ++i) { + size_t input_i_shape_size = inputs[i]->shape_size_; + if (input_i_shape_size != input0_shape_size) { + return NNACL_PARAM_INVALID; + } + int shape_tmp[MAX_SHAPE_SIZE] = {0}; + size_t shape_tmp_size = 0; + ShapeSet(shape_tmp, &shape_tmp_size, inputs[i]->shape_, inputs[i]->shape_size_); + int data_type_judge = DataTypeJudge(inputs[i], output); + if (data_type_judge != NNACL_OK) { + return data_type_judge; + } + int axis_tmp = shape_tmp[axis]; + erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + if (!ShapeEqual(input0_shape_without_axis, input0_shape_without_axis_size, shape_tmp, shape_tmp_size)) { + return NNACL_ERR; + } + output_axis_dim += axis_tmp; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input0_shape_size; + for (size_t i = 0; i < input0_shape_size; i++) { + output_shape[i] = input0_shape[i]; + } + output_shape[axis] = output_axis_dim; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Concat, PrimType_Concat, ConcatInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.h new file mode 100644 index 00000000..c743f3cf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/concat_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
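A worked example of the axis handling above, with hypothetical shapes:

    /* inputs {1, 2, 8, 3} and {1, 2, 8, 5}, param->axis_ = -1 -> axis 3.
     * Non-axis dims must match after ShapeErase; the axis dims accumulate:
     * output {1, 2, 8, 3 + 5} = {1, 2, 8, 8}.
     * DataTypeJudge tolerates a float16/float32 mix; any other dtype
     * mismatch fails with NNACL_PARAM_INVALID. */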
+ */ +#ifndef MINDSPORE_NNACL_CONCAT_INFER_H +#define MINDSPORE_NNACL_CONCAT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/concat_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONCAT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.c new file mode 100644 index 00000000..e595b041 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.c @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/constant_of_shape_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ConstantOfShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + ConstantOfShapeParameter *param = (ConstantOfShapeParameter *)parameter; + out_tensor->data_type_ = (TypeIdC)(param->data_type_); + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size) || in_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int size = NNACLGetElementNum(in_tensor); + if (size < 0 || size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE]; + int out_shape_size = size; + switch (in_tensor->data_type_) { + case kNumberTypeInt32: { + int32_t *in_data = (int32_t *)(in_tensor->data_); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + if (out_shape[i] < 0) { + return NNACL_ERR; + } + } + break; + } + case kNumberTypeInt64: { + int64_t *in_data = (int64_t *)(in_tensor->data_); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + if (out_shape[i] < 0) { + return NNACL_ERR; + } + } + break; + } + default: + return NNACL_INFER_INVALID; + } + + SetShapeArray(out_tensor, out_shape, (size_t)out_shape_size); + return NNACL_OK; +} + +REG_INFER(ConstantOfShape, PrimType_ConstantOfShape, ConstantOfShapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.h new file mode 100644 index 00000000..b7ccc57c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/constant_of_shape_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
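ConstantOfShape likewise takes its output shape from the input's data payload. A caller sketch with an int32 payload and made-up dims; the cast to OpParameter assumes the usual nnacl layout with op_parameter_ as the first member:

    #include "nnacl_c/infer/constant_of_shape_infer.h"
    #include "nnacl_c/constant_of_shape_parameter.h"

    int ExampleConstantOfShape(TensorC *out) {
      TensorC shape_in = {0};
      ConstantOfShapeParameter param = {0};
      param.data_type_ = kNumberTypeFloat32;  /* becomes out->data_type_ above */
      int dims[3] = {2, 3, 4};
      shape_in.data_type_ = kNumberTypeInt32;
      shape_in.data_ = dims;
      int s[1] = {3};
      SetShapeArray(&shape_in, s, 1);
      const TensorC *inputs[1] = {&shape_in};
      TensorC *outputs[1] = {out};
      return ConstantOfShapeInferShape(inputs, 1, outputs, 1, (OpParameter *)&param);  /* out: {2, 3, 4} */
    }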
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H +#define MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConstantOfShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.c new file mode 100644 index 00000000..de4b0cdb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensor_array_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_array_parameter.h" + +int TensorArrayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { +#ifdef Debug + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } +#endif + + TensorC *output = outputs[0]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + TensorArrayParameter *param = (TensorArrayParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + output->data_type_ = param->data_type_; + SetShapeArray(output, param->element_shape_, (size_t)param->element_shape_size_); + + return NNACL_OK; +} + +REG_INFER(TensorArray, PrimType_TensorArray, TensorArrayInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.h new file mode 100644 index 00000000..6bdb26a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.c new file mode 100644 index 00000000..fbbf22a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensor_array_read_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_array_parameter.h" + +int TensorArrayReadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // { prim, handle, index } -> node + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size >= 1, NNACL_ERR); + NNACL_CHECK_TRUE_RET(outputs_size >= 1, NNACL_ERR); + TensorC *handle = (TensorC *)inputs[0]; + TensorC *output = outputs[0]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + output->data_type_ = handle->data_type_; + SetShapeArray(output, handle->shape_, handle->shape_size_); + + return NNACL_OK; +} + +REG_INFER(TensorArrayRead, PrimType_TensorArrayRead, TensorArrayReadInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.h new file mode 100644 index 00000000..a9fa6fa9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_read_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayReadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.c new file mode 100644 index 00000000..460bee83 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/control/tensor_array_write_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_array_parameter.h" + +int TensorArrayWriteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // { handle, index, value, flow_in } -> empty + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size >= 3, NNACL_ERR); + TensorC *handle = (TensorC *)inputs[0]; + TensorC *value = (TensorC *)inputs[2]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + TensorArrayParameter *param = (TensorArrayParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + if (handle->shape_size_ != value->shape_size_) { + return NNACL_INFER_INVALID; + } + + for (size_t i = 0; i < handle->shape_size_; ++i) { + if (handle->shape_[i] != value->shape_[i]) { + return NNACL_INFER_INVALID; + } + } + + return NNACL_OK; +} + +REG_INFER(TensorArrayWrite, PrimType_TensorArrayWrite, TensorArrayWriteInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.h new file mode 100644 index 00000000..224c44e7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensor_array_write_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayWriteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.c new file mode 100644 index 00000000..c812e2d3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensorlist_fromtensor_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorListC *output = (TensorListC *)(outputs[0]); + const TensorC *input0 = inputs[0]; + output->data_type_ = kObjectTypeTensorType; + output->format_ = Format_NHWC; + output->tensors_data_type_ = input0->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input0->shape_size_ < 1) { + return NNACL_ERR; + } + int dim0 = input0->shape_[0]; + if (dim0 < 0) { + return NNACL_ERR; + } + const TensorC *input1 = inputs[1]; + if (input1->data_ == NULL) { + return NNACL_NULL_PTR; + } + int *ele_shape_ptr = (int *)(input1->data_); + NNACL_CHECK_NULL_RETURN_ERR(ele_shape_ptr); + vvector tensor_shape; + tensor_shape.size_ = (size_t)(dim0); + tensor_shape.shape_ = (int **)malloc(tensor_shape.size_ * sizeof(int *)); + if (tensor_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + tensor_shape.shape_size_ = (int *)malloc(tensor_shape.size_ * sizeof(int)); + if (tensor_shape.shape_size_ == NULL) { + free(tensor_shape.shape_); + return NNACL_NULL_PTR; + } + for (int i = 0; i < dim0; i++) { + tensor_shape.shape_[i] = (int *)(input0->shape_ + 1); + tensor_shape.shape_size_[i] = (int)(input0->shape_size_) - 1; + } + + ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, (size_t)NNACLGetElementNum(input1)); + output->element_num_ = (size_t)(dim0); + int ret = MallocTensorListData(output, input0->data_type_, &tensor_shape); + if (ret != NNACL_OK) { + free(tensor_shape.shape_); + free(tensor_shape.shape_size_); + return NNACL_ERR; + } + free(tensor_shape.shape_); + free(tensor_shape.shape_size_); + return NNACL_OK; +} + +REG_INFER(TensorListFromTensor, PrimType_TensorListFromTensor, TensorListFromTensorInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.h new file mode 100644 index 00000000..884e0422 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_fromtensor_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.c new file mode 100644 index 00000000..227043f3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.c @@ -0,0 +1,102 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensorlist_getitem_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (inputs[0]->data_type_ != kObjectTypeTensorType) { + return NNACL_ERR; + } + TensorListC *input0 = (TensorListC *)(inputs[0]); + const TensorC *get_index = inputs[1]; + if (get_index->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(get_index) != 1) { + return NNACL_ERR; + } + TensorC *output = outputs[0]; + if (!InferFlag(inputs, inputs_size) || input0->element_num_ == 0) { + return NNACL_INFER_INVALID; + } + int index = ((int *)(get_index->data_))[0]; + if (index < 0 || index > ((int)(input0->element_num_ - 1))) { + return NNACL_ERR; + } + TensorC *tensor_index = input0->tensors_[index]; + NNACL_CHECK_NULL_RETURN_ERR(tensor_index); + + if (tensor_index->data_type_ != kTypeUnknown) { + output->data_type_ = tensor_index->data_type_; + } else { + output->data_type_ = input0->tensors_data_type_; + } + output->format_ = input0->tensors_[index]->format_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (tensor_index->data_type_ != kTypeUnknown) { + ShapeSet(output->shape_, &(output->shape_size_), tensor_index->shape_, tensor_index->shape_size_); + } else { + const TensorC *input2 = inputs[2]; + NNACL_CHECK_NULL_RETURN_ERR(input2); + NNACL_CHECK_NULL_RETURN_ERR(input2->data_); + int *ele_shape_data = (int *)(input2->data_); + NNACL_CHECK_NULL_RETURN_ERR(ele_shape_data); + int element_shape[MAX_SHAPE_SIZE] = {0}; + size_t element_shape_size = 0; + for (int i = 0; i < NNACLGetElementNum(input2); ++i) { + ShapePush(element_shape, &element_shape_size, ele_shape_data[i]); + } + int status = + 
TensorListMergeShape(element_shape, &element_shape_size, input0->element_shape_, input0->element_shape_size_);
+    if (status != NNACL_OK) {
+      return NNACL_ERR;
+    }
+    if (!TensorListIsFullyDefined(element_shape, element_shape_size)) {
+      for (size_t i = 0; i < input0->element_num_; ++i) {
+        TensorC *input = input0->tensors_[i];
+        NNACL_CHECK_NULL_RETURN_ERR(input);
+        if (input->data_type_ != kTypeUnknown) {
+          status = TensorListMergeShape(element_shape, &element_shape_size, input->shape_, input->shape_size_);
+          if (status != NNACL_OK) {
+            return NNACL_ERR;
+          }
+        }
+      }
+    }
+    if (!TensorListIsFullyDefined(element_shape, element_shape_size)) {  // re-check after merging per-element shapes
+      return NNACL_ERR;
+    }
+
+    SetShapeArray(output, element_shape, element_shape_size);
+  }
+
+  return NNACL_OK;
+}
+
+REG_INFER(TensorListGetItem, PrimType_TensorListGetItem, TensorListGetItemInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.h
new file mode 100644
index 00000000..19d4804a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_getitem_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H
+#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/tensorlist_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                size_t outputs_size, OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.c
new file mode 100644
index 00000000..544041a7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.c
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/control/tensorlist_reserve_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensorlist_parameter.h"
+#include "nnacl_c/tensorlist_c_utils.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                size_t outputs_size, OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  TensorListParameter *reserve_param = (TensorListParameter *)parameter;
+  const TensorC *input0 = inputs[0];
+  int ele_shape_type = input0->data_type_;
+  if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) {
+    return NNACL_ERR;
+  }
+
+  TensorListC *output = (TensorListC *)(outputs[0]);
+  output->data_type_ = kObjectTypeTensorType;
+  output->format_ = Format_NHWC;
+  output->tensors_data_type_ = reserve_param->element_dtype_;
+
+  if (input0->data_ == NULL) {
+    return NNACL_INFER_INVALID;
+  }
+  int *ele_shape_ptr = (int *)(input0->data_);
+
+  const TensorC *input1 = inputs[1];
+  int num_ele_type = input1->data_type_;
+  if (num_ele_type != kNumberTypeInt && num_ele_type != kNumberTypeInt32) {
+    return NNACL_ERR;
+  }
+  if (input1->data_ == NULL) {
+    return NNACL_INFER_INVALID;
+  }
+  if (NNACLGetElementNum(input1) != 1) {
+    return NNACL_ERR;
+  }
+  int num_elements = ((int *)(input1->data_))[0];
+  if (num_elements < 0) {
+    return NNACL_ERR;
+  }
+  ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, (size_t)NNACLGetElementNum(input0));
+  output->element_num_ = (size_t)(num_elements);
+
+  vvector tmp_shape;
+  tmp_shape.size_ = (size_t)(num_elements);
+  tmp_shape.shape_ = (int **)malloc(tmp_shape.size_ * sizeof(int *));
+  if (tmp_shape.shape_ == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  tmp_shape.shape_size_ = (int *)malloc(tmp_shape.size_ * sizeof(int));
+  if (tmp_shape.shape_size_ == NULL) {
+    free(tmp_shape.shape_);
+    return NNACL_NULL_PTR;
+  }
+
+  for (int i = 0; i < num_elements; ++i) {
+    tmp_shape.shape_size_[i] = output->element_shape_size_;
+    tmp_shape.shape_[i] = output->element_shape_;
+  }
+  int ret = MallocTensorListData(output, reserve_param->element_dtype_, &tmp_shape);
+  free(tmp_shape.shape_size_);
+  free(tmp_shape.shape_);
+  return ret;
+}
+
+REG_INFER(TensorListReserve, PrimType_TensorListReserve, TensorListReserveInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.h
new file mode 100644
index 00000000..1a753f36
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_reserve_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H
+#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                size_t outputs_size, OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.c
new file mode 100644
index 00000000..4b2c77ce
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.c
@@ -0,0 +1,129 @@
+/**
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/control/tensorlist_setitem_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensorlist_c_utils.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int PreJudge(const TensorC *get_index, TensorListC *input0, const TensorC *value_tensor) {
+  if (get_index->data_ == NULL) {
+    return NNACL_INFER_INVALID;
+  }
+
+  if (get_index->data_type_ != kNumberTypeInt && get_index->data_type_ != kNumberTypeInt32) {
+    return NNACL_ERR;
+  }
+  if (NNACLGetElementNum(get_index) != 1) {
+    return NNACL_ERR;
+  }
+  return NNACL_OK;
+}
+
+int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                size_t outputs_size, OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  TensorListC *input0 = (TensorListC *)(inputs[0]);
+  const TensorC *get_index = inputs[1];
+  const TensorC *value_tensor = inputs[2];
+  TensorListC *output0 = (TensorListC *)(outputs[0]);
+  output0->data_type_ = input0->data_type_;
+  output0->format_ = input0->format_;
+  output0->tensors_data_type_ = value_tensor->data_type_;
+
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  int judge_ret = PreJudge(get_index, input0, value_tensor);
+  if (judge_ret != NNACL_OK) {
+    return judge_ret;
+  }
+
+  int index = ((int *)(get_index->data_))[0];
+  if (index < 0 || (size_t)index > input0->element_num_) {  // at most one slot past the end may be written
+    return NNACL_ERR;
+  }
+  output0->max_elements_num_ = input0->max_elements_num_;
+
+  if (input0->element_num_ == 0 && input0->element_shape_size_ == 0 && index == 0) {
+    ShapeSet(input0->element_shape_, &(input0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_);
+    ShapeSet(output0->element_shape_, &(output0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_);
+  } else {
+    ShapeSet(output0->element_shape_, &(output0->element_shape_size_), input0->element_shape_,
+             input0->element_shape_size_);
+  }
+
+  vvector out_shape;
+  out_shape.size_ = 0;
+  out_shape.shape_ 
= (int **)malloc((input0->element_num_ + 1) * sizeof(int *)); + if (out_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + out_shape.shape_size_ = (int *)malloc((input0->element_num_ + 1) * sizeof(int)); + if (out_shape.shape_size_ == NULL) { + free(out_shape.shape_); + return NNACL_NULL_PTR; + } + + if (index == 0 && input0->element_num_ == 0) { // uninitialized tensorlist + out_shape.shape_[out_shape.size_] = (int *)(value_tensor->shape_); + out_shape.shape_size_[out_shape.size_] = value_tensor->shape_size_; + out_shape.size_++; + output0->element_num_ = 1; + } else { + output0->element_num_ = input0->element_num_; + for (size_t i = 0; i < input0->element_num_; ++i) { + TensorC *src_ptr = input0->tensors_[i]; + if (src_ptr == NULL) { + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_NULL_PTR; + } + if (src_ptr->data_type_ != kTypeUnknown) { + out_shape.shape_[out_shape.size_] = src_ptr->shape_; + out_shape.shape_size_[out_shape.size_] = (int)(src_ptr->shape_size_); + out_shape.size_++; + } else { + out_shape.shape_[out_shape.size_] = NULL; + out_shape.shape_size_[out_shape.size_] = 0; + out_shape.size_++; + } + } + } + + if (input0->tensors_data_type_ == kTypeUnknown) { + input0->tensors_data_type_ = value_tensor->data_type_; + } + + out_shape.shape_[index] = (int *)(value_tensor->shape_); + out_shape.shape_size_[index] = (int)value_tensor->shape_size_; + int ret = MallocTensorListData(output0, input0->tensors_data_type_, &out_shape); + if (ret != NNACL_OK) { + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_ERR; + } + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_OK; +} + +REG_INFER(TensorListSetItem, PrimType_TensorListSetItem, TensorListSetItemInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.h new file mode 100644 index 00000000..066860e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_setitem_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.c new file mode 100644 index 00000000..bf15df5e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.c @@ -0,0 +1,96 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/control/tensorlist_stack_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" + +int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *output = outputs[0]; + if (inputs[0]->data_type_ != kObjectTypeTensorType) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorListC *input0 = (TensorListC *)(inputs[0]); + output->data_type_ = input0->tensors_data_type_; + output->format_ = input0->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input0->element_num_ == 0) { + return NNACL_INFER_INVALID; + } + const TensorC *ele_shape = inputs[1]; // element shape + if (ele_shape->data_ == NULL) { + return NNACL_NULL_PTR; + } + int *ele_shape_ptr = (int *)(ele_shape->data_); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (ele_shape_ptr[0] == -1) { + if (input0->element_shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (size_t i = 0; i < input0->element_shape_size_; i++) { + ShapePush(output_shape, &output_shape_size, input0->element_shape_[i]); + } + } else { + int ele_shape_num = NNACLGetElementNum(ele_shape); + if (ele_shape_num > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < ele_shape_num; ++i) { + ShapePush(output_shape, &output_shape_size, ele_shape_ptr[i]); + } + } + + int status = + TensorListMergeShape(output_shape, &output_shape_size, input0->element_shape_, input0->element_shape_size_); + if (status == NNACL_ERR) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(output_shape, output_shape_size)) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(input0->element_shape_, input0->element_shape_size_)) { + for (size_t i = 0; i < input0->element_num_; ++i) { + TensorC *tensor_ele = 
input0->tensors_[i]; + if (tensor_ele->data_type_ != kTypeUnknown) { + status = TensorListMergeShape(output_shape, &output_shape_size, tensor_ele->shape_, tensor_ele->shape_size_); + if (status == NNACL_ERR) { + return NNACL_ERR; + } + } + } + } + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ret = ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(TensorListStack, PrimType_TensorListStack, TensorListStackInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.h new file mode 100644 index 00000000..d9c2b84a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/control/tensorlist_stack_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.c new file mode 100644 index 00000000..44e73686 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <string.h>
+#include "nnacl_c/infer/conv2d_grad_filter_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                               OpParameter *parameter) {
+  int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  if (inputs_size < 3 || outputs_size != 1) {
+    return NNACL_ERR;
+  }
+  if (inputs[FIRST_INPUT]->format_ != Format_NHWC || inputs[SECOND_INPUT]->format_ != Format_NHWC) {
+    return NNACL_FORMAT_ERROR;
+  }
+  SetDataTypeFormat(outputs[FIRST_INPUT], inputs[FIRST_INPUT]);
+
+  if (inputs[THIRD_INPUT]->shape_size_ < DIMENSION_1D || inputs[THIRD_INPUT]->data_ == NULL) {
+    return NNACL_ERR;
+  }
+  if (inputs[THIRD_INPUT]->shape_[kNCHW_N] < 0) {
+    return NNACL_ERR;
+  }
+  size_t filter_shape_size = (size_t)(inputs[THIRD_INPUT]->shape_[kNCHW_N]);
+  if (filter_shape_size != DIMENSION_4D) {
+    return NNACL_ERR;
+  }
+
+  int filter_shape[MAX_SHAPE_SIZE];
+  if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) {
+    const int nchw2nhwc[] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C};
+    for (size_t i = 0; i < filter_shape_size; i++) {
+      filter_shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]);
+    }
+  } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) {
+    memcpy(filter_shape, inputs[THIRD_INPUT]->data_, filter_shape_size * sizeof(int));
+  } else {
+    return NNACL_ERR;
+  }
+  SetShapeArray(outputs[0], filter_shape, filter_shape_size);
+  return NNACL_OK;
+}
+
+REG_INFER(Conv2DBackpropFilterFusion, PrimType_Conv2DBackpropFilterFusion, Conv2dGradFilterInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.h
new file mode 100644
index 00000000..4939068c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_filter_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H +#define MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.c new file mode 100644 index 00000000..62311d21 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.c @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/conv2d_grad_input_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + if (in0 == NULL || out == NULL) { + return NNACL_NULL_PTR; + } + if (in0->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(out, in0); + + if (inputs[THIRD_INPUT]->shape_size_ < 1 || inputs[THIRD_INPUT]->data_ == NULL) { + return NNACL_ERR; + } + size_t data_size = (size_t)inputs[2]->shape_[0]; + if (data_size != 4) { + return NNACL_ERR; + } + + int shape[MAX_SHAPE_SIZE]; + if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) { + const int nchw2nhwc[4] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C}; + for (size_t i = 0; i < data_size; i++) { + shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]); + } + } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) { + memcpy(shape, inputs[THIRD_INPUT]->data_, data_size * sizeof(int)); + } else { + return NNACL_ERR; + } + SetShapeArray(out, shape, data_size); + return NNACL_OK; +} + +REG_INFER(Conv2DBackpropInputFusion, PrimType_Conv2DBackpropInputFusion, Conv2dGradInputInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.h new file mode 100644 index 00000000..58349fed --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_grad_input_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H +#define MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.c new file mode 100644 index 00000000..42ff7d22 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.c @@ -0,0 +1,169 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "nnacl_c/infer/conv2d_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int ConvInferShape(int input_h, int input_w, int *output_h, int *output_w, ConvParameter *param) {
+  int kernel_w = param->kernel_w_;
+  int kernel_h = param->kernel_h_;
+  int stride_w = param->stride_w_;
+  int stride_h = param->stride_h_;
+  int dilate_w = param->dilation_w_;
+  int dilate_h = param->dilation_h_;
+
+  if (stride_w == 0 || stride_h == 0) {
+    return NNACL_PARAM_INVALID;
+  }
+  if (INT_MUL_OVERFLOW(kernel_h, dilate_h) || INT_MUL_OVERFLOW(kernel_w, dilate_w)) {
+    return NNACL_ERRCODE_MUL_OVERFLOW;
+  }
+  if (param->pad_mode_ == Pad_same) {  // maybe error
+    *output_w = ceil((float)(input_w) / (float)(stride_w));
+    *output_h = ceil((float)(input_h) / (float)(stride_h));
+    int pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h);
+    int pad_w_all = ((*output_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - input_w);
+    if (pad_h_all < 0) {
+      param->pad_u_ = param->pad_d_ = 0;
+    } else {
+      param->pad_u_ = pad_h_all / 2;
+      param->pad_d_ = pad_h_all - param->pad_u_;
+    }
+    if (pad_w_all < 0) {
+      param->pad_l_ = param->pad_r_ = 0;
+    } else {
+      param->pad_l_ = pad_w_all / 2;
+      param->pad_r_ = pad_w_all - param->pad_l_;
+    }
+  } else if (param->pad_mode_ == Pad_valid) {
+    *output_w = ceil(((float)(input_w) + param->pad_l_ + param->pad_r_ - ((float)(kernel_w)-1) * (float)(dilate_w)) /
+                     (float)(stride_w));
+    *output_h = ceil(((float)(input_h) + param->pad_u_ + param->pad_d_ - ((float)(kernel_h)-1) * (float)(dilate_h)) /
+                     (float)(stride_h));
+  } else {
+    int kernel_width = (kernel_w - 1) * dilate_w + 1;
+    int kernel_height = (kernel_h - 1) * dilate_h + 1;
+    *output_w = ((input_w) + param->pad_l_ + param->pad_r_ - kernel_width) / stride_w + 1;
+    *output_h = ((input_h) + param->pad_u_ + param->pad_d_ - kernel_height) / stride_h + 1;
+  }
+
+  if (param->kernel_h_ > input_h + param->pad_u_ + param->pad_d_ ||
+      param->kernel_w_ > input_w + param->pad_l_ + param->pad_r_) {
+    return NNACL_PARAM_INVALID;
+  }
+  return NNACL_OK;
+}
+
+static const int MAX_CONV_KERNEL_DIM = 10000;  // Upper bound: no valid conv kernel dimension should reach this value.
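
As a quick sanity check of the SAME/VALID arithmetic in ConvInferShape above, the following standalone snippet (illustrative only, not part of the patched sources; plain C with no nnacl_c dependencies, names ad hoc) reproduces both branches for a 32x32 input with a 3x3 kernel, stride 2, dilation 1:

#include <math.h>
#include <stdio.h>

int main(void) {
  const int in = 32, kernel = 3, stride = 2, dilate = 1;

  /* Pad_same branch: the output covers the whole input; padding is derived afterwards. */
  int out_same = (int)ceil((double)in / stride);                          /* 16 */
  int pad_all = (out_same - 1) * stride + (kernel - 1) * dilate + 1 - in; /* 1 */
  int pad_u = pad_all < 0 ? 0 : pad_all / 2;                              /* 0 */
  int pad_d = pad_all < 0 ? 0 : pad_all - pad_u;                          /* 1 */

  /* Pad_valid branch with zero explicit padding. */
  int out_valid = (int)ceil((double)(in - (kernel - 1) * dilate) / stride); /* 15 */

  printf("same: out=%d pads=(%d,%d)  valid: out=%d\n", out_same, pad_u, pad_d, out_valid);
  return 0;
}

Note the asymmetric SAME padding: when the total required padding is odd, the extra row/column goes to the bottom/right (pad_d_, pad_r_), matching the pad_all / 2 split in the code above.
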
+ +int CheckConvAttr(const int input_c, const TensorC *weight_tensor, const ConvParameter *param) { + // common conv: input_c == weight_tensor->shape_[3] + // conv depthwise: input_c == 1 + // group conv: input_c / group == weight_tensor->shape_[3] + NNACL_CHECK_FALSE(param->group_ == 0, NNACL_PARAM_INVALID); + if (input_c != weight_tensor->shape_[3] && input_c != 1 && (input_c / param->group_) != weight_tensor->shape_[3]) { + return NNACL_PARAM_INVALID; + } + + // common conv: group == 1 + // conv depthwise: group == input_c == output_c + // group conv: group == input_c / weight_tensor->shape_[3] + NNACL_CHECK_FALSE(weight_tensor->shape_[3] == 0, NNACL_PARAM_INVALID); + if (param->group_ != 1 && param->group_ != input_c && param->group_ != (input_c / weight_tensor->shape_[3])) { + return NNACL_PARAM_INVALID; + } + if (param->stride_h_ <= 0 || param->stride_w_ <= 0) { + return NNACL_PARAM_INVALID; + } + + if ((param->kernel_h_ >= MAX_CONV_KERNEL_DIM) || (param->kernel_w_ >= MAX_CONV_KERNEL_DIM)) { + return NNACL_PARAM_INVALID; + } + + NNACL_CHECK_TRUE_RET(param->kernel_h_ == weight_tensor->shape_[1], NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(param->kernel_w_ == weight_tensor->shape_[2], NNACL_PARAM_INVALID); + return NNACL_OK; +} + +int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input_tensor = inputs[0]; + if (input_tensor->format_ != Format_NHWC && input_tensor->format_ != Format_KHWC && + input_tensor->format_ != Format_NC4HW4 && input_tensor->format_ != Format_NC8HW8) { + return NNACL_FORMAT_ERROR; + } + const TensorC *weight_tensor = inputs[1]; + if (weight_tensor->format_ != Format_NHWC && weight_tensor->format_ != Format_KHWC) { + return NNACL_FORMAT_ERROR; + } + TensorC *out_tensor = outputs[0]; + if (out_tensor->format_ != Format_NC4HW4) { + out_tensor->format_ = input_tensor->format_; + } + out_tensor->data_type_ = input_tensor->data_type_; + ConvParameter *param = (ConvParameter *)parameter; + if (param->group_ == 0) { + param->group_ = weight_tensor->shape_[0]; + } + param->output_channel_ = weight_tensor->shape_[0]; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : weight_tensor->shape_[1]; + param->kernel_w_ = param->kernel_w_ != -1 ? param->kernel_w_ : weight_tensor->shape_[2]; + + if (input_tensor->shape_size_ == 0) { + return NNACL_INFER_INVALID; + } + + int ret = CheckConvAttr(NNACLGetChannel(input_tensor), weight_tensor, param); + if (ret != NNACL_OK) { + return ret; + } + + int output_w = 0, output_h = 0; + ret = ConvInferShape(NNACLGetHeight(input_tensor), NNACLGetWidth(input_tensor), &output_h, &output_w, param); + if (ret != NNACL_OK) { + return ret; + } + + out_tensor->shape_size_ = input_tensor->shape_size_; + NNACLSetBatch(out_tensor, NNACLGetBatch(input_tensor)); + NNACLSetChannel(out_tensor, NNACLGetBatch(weight_tensor)); + output_h = output_h >= 0 ? output_h : 1; + NNACLSetHeight(out_tensor, output_h); + output_w = output_w >= 0 ? 
output_w : 1;
+  NNACLSetWidth(out_tensor, output_w);
+
+  param->input_batch_ = NNACLGetBatch(input_tensor);
+  param->input_h_ = NNACLGetHeight(input_tensor);
+  param->input_w_ = NNACLGetWidth(input_tensor);
+  param->input_channel_ = NNACLGetChannel(input_tensor);
+  param->output_batch_ = NNACLGetBatch(out_tensor);
+  param->output_h_ = NNACLGetHeight(out_tensor);
+  param->output_w_ = NNACLGetWidth(out_tensor);
+  param->output_channel_ = NNACLGetChannel(out_tensor);
+  param->out_format_ = out_tensor->format_;
+  return NNACL_OK;
+}
+
+REG_INFER(Adder, PrimType_AdderFusion, Conv2dInferShape)
+REG_INFER(Conv2D, PrimType_Conv2DFusion, Conv2dInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.h
new file mode 100644
index 00000000..c38b83b4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv2d_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CONV2D_INFER_H
+#define MINDSPORE_NNACL_CONV2D_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_CONV2D_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.c
new file mode 100644
index 00000000..fbed116e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.c
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/infer/conv3d_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int Conv3dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     OpParameter *parameter) {
+  // Conv3D shape inference is intentionally not implemented here; this stub only keeps the InferShape pass
+  // from being interrupted and leaves the node shapes as {}.
+ return NNACL_OK; +} + +REG_INFER(Conv3D, PrimType_Inner_Conv3D, Conv3dInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.h new file mode 100644 index 00000000..9c4ee48a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/conv3d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV3D_INFER_H +#define MINDSPORE_NNACL_CONV3D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv3dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV3D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.c new file mode 100644 index 00000000..cd8e2f45 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.c @@ -0,0 +1,69 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/crop_and_resize_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int CropAndResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 0 && input->shape_size_ != 4) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (NNACLGetBatch(input) == 0) { + ShapePush(output_shape, &output_shape_size, 0); + } else if (inputs[1]->data_ != NULL) { + const TensorC *boxes_tensor = inputs[1]; + if (boxes_tensor->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapePush(output_shape, &output_shape_size, boxes_tensor->shape_[0]); + } else { + return NNACL_INFER_INVALID; + } + + const TensorC *shape_tensor = inputs[3]; + int32_t *data = (int32_t *)(shape_tensor->data_); + if (data == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(shape_tensor) < 2) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapePush(output_shape, &output_shape_size, data[0]); + ShapePush(output_shape, &output_shape_size, data[1]); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(CropAndResize, PrimType_CropAndResize, CropAndResizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.h new file mode 100644 index 00000000..7571b88e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_and_resize_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H +#define MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CropAndResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.c new file mode 100644 index 00000000..da920180 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.c @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/crop_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/crop_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +int CropInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + size_t input_shape_size = inputs[0]->shape_size_; + CropParameter *param = (CropParameter *)parameter; + int64_t axis = param->axis_ < 0 ? param->axis_ + (int64_t)input_shape_size : param->axis_; + if (axis < 0 || axis >= (int64_t)input_shape_size) { + return NNACL_ERR; + } + + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[1]); + return NNACL_OK; +} + +REG_INFER(Crop, PrimType_Crop, CropInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.h new file mode 100644 index 00000000..aab29737 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/crop_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CROP_INFER_H +#define MINDSPORE_NNACL_CROP_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CropInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CROP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.c new file mode 100644 index 00000000..ebd1d0e4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/cumsum_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CumsumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Cumsum, PrimType_CumSum, CumsumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.h new file mode 100644 index 00000000..877ae308 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/cumsum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUMSUM_INFER_H +#define MINDSPORE_NNACL_CUMSUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CumsumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUMSUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.c new file mode 100644 index 00000000..2963e460 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_gru_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C6NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != C3NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + SetShapeTensor(output, input); + const TensorC *weight_in = inputs[1]; + if (weight_in->shape_size_ != C2NUM || weight_in->shape_[0] % C3NUM != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + output->shape_[C2NUM] = weight_in[0].shape_[0] / C3NUM; + return NNACL_OK; +} + +REG_INFER(CustomGru, PrimType_Inner_CustomGru, CustomGruInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.h new file mode 100644 index 00000000..d154a971 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_gru_infer.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
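+ */ +/* Documentation note (derived from custom_gru_infer.c above): weight_in packs the three GRU gates along dim 0, so the hidden size is weight_in->shape_[0] / C3NUM; e.g. a rank-3 input {batch, seq, in_size} with weight_in {192, in_size} infers output {batch, seq, 64}. */ +/**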
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#define MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_GRU_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.c new file mode 100644 index 00000000..740eff84 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_is_inf_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C1NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + output->data_type_ = kNumberTypeBool; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(CustomIsInf, PrimType_Inner_CustomIsInf, CustomIsInfInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.h new file mode 100644 index 00000000..87b8731e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_is_inf_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H +#define MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.c new file mode 100644 index 00000000..302dec2f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_masked_fill_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(CustomMaskedFill, PrimType_Inner_CustomMaskedFill, CustomMaskedFillInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.h new file mode 100644 index 00000000..844f5e0e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_masked_fill_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H +#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.c new file mode 100644 index 00000000..da6b55e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/custom_tensor_scatter_max_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(CustomTensorScatterMax, PrimType_Inner_CustomTensorScatterMax, CustomTensorScatterMaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.h new file mode 100644 index 00000000..f19cccb7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/custom_tensor_scatter_max_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H +#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif  // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.c new file mode 100644 index 00000000..6257e58e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.c @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/decoder_layer_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DecoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C16NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(DecoderLayer, PrimType_Inner_DecoderLayer, DecoderLayerInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.h new file mode 100644 index 00000000..2b894fd1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/decoder_layer_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DecoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.c new file mode 100644 index 00000000..01b243fc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.c @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/deconv2d_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + const TensorC *weight = inputs[1]; + TensorC *output = outputs[0]; + output->format_ = input->format_; + output->data_type_ = input->data_type_; + + ConvParameter *param = (ConvParameter *)parameter; + if (param->group_ == 0) { + param->group_ = weight->shape_[0]; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int32_t input_h = NNACLGetHeight(input); + int32_t input_w = NNACLGetWidth(input); + + int32_t output_n = NNACLGetBatch(input); + int32_t output_h = 0; + int32_t output_w = 0; + int32_t output_c = NNACLGetChannel(weight); + NNACL_CHECK_TRUE_RET(NNACLGetChannel(input) == NNACLGetBatch(weight), NNACL_ERR); + if (param->group_ == NNACLGetChannel(input) && 1 == NNACLGetChannel(weight)) { + output_c = NNACLGetBatch(weight); /* depthwise */ + } + + int kernel_w = param->kernel_w_ != -1 ? param->kernel_w_ : NNACLGetWidth(weight); + int kernel_h = param->kernel_h_ != -1 ? 
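/* kernel_h/kernel_w fall back to the weight tensor's spatial dims when unset (-1); the output sizes computed below invert the convolution arithmetic, e.g. input_h = 4, stride_h = 2, kernel_h = 3, dilation_h = 1, pad_u_ = pad_d_ = 1 under Pad_pad gives (4 - 1) * 2 + 3 - 2 = 7. */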
param->kernel_h_ : NNACLGetHeight(weight); + NNACL_CHECK_FALSE(kernel_w <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(kernel_h <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_h, kernel_w), NNACL_ERR); + + int stride_w = param->stride_w_; + int stride_h = param->stride_h_; + NNACL_CHECK_FALSE(stride_w <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(stride_h <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input_h, stride_h), NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input_w, stride_w), NNACL_ERR); + + int dilate_w = param->dilation_w_; + int dilate_h = param->dilation_h_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_h, dilate_h), NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_w, dilate_w), NNACL_ERR); + + int pad_mode = param->pad_mode_; + if (pad_mode == Pad_pad) { + output_h = (input_h - 1) * stride_h + ((kernel_h - 1) * dilate_h + 1) - param->pad_u_ - param->pad_d_; + output_w = (input_w - 1) * stride_w + ((kernel_w - 1) * dilate_w + 1) - param->pad_l_ - param->pad_r_; + } else if (pad_mode == Pad_same) { + output_h = input_h * stride_h; + output_w = input_w * stride_w; + } else if (pad_mode == Pad_valid) { + output_h = (input_h - 1) * stride_h + kernel_h; + output_w = (input_w - 1) * stride_w + kernel_w; + } else { + return NNACL_ERR; + } + + output_h += param->output_padding_h_; + output_w += param->output_padding_w_; + + output->shape_size_ = 4; + output->shape_[0] = output_n; + output->shape_[1] = output_h; + output->shape_[2] = output_w; + output->shape_[3] = output_c; + + if (pad_mode == Pad_same) { + param->pad_u_ = ((input_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - output_h) / 2; + param->pad_l_ = ((input_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - output_w) / 2; + } else if (pad_mode == Pad_valid) { + param->pad_u_ = 0; + param->pad_l_ = 0; + } + + const int *in_shape = input->shape_; + param->input_batch_ = in_shape[0]; + param->input_h_ = in_shape[1]; + param->input_w_ = in_shape[2]; + param->input_channel_ = in_shape[3]; + param->output_batch_ = output_n; + param->output_h_ = output_h; + param->output_w_ = output_w; + param->output_channel_ = output_c; + param->kernel_h_ = kernel_h; + param->kernel_w_ = kernel_w; + return NNACL_OK; +} + +REG_INFER(Conv2dTranspose, PrimType_Conv2dTransposeFusion, Deconv2dInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.h new file mode 100644 index 00000000..a2c713b1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/deconv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_DECONV2D_INFER_H +#define MINDSPORE_NNACL_DECONV2D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DECONV2D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.c new file mode 100644 index 00000000..bdd1eb1f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/depth_to_space_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + DepthToSpaceParameter *param = (DepthToSpaceParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_PARAM_INVALID; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + + int32_t block_size = param->block_size_; + if (INT_MUL_OVERFLOW(block_size, block_size)) { + return NNACL_PARAM_INVALID; + } + if (block_size == 0 || input_shape[kNHWC_C] % (block_size * block_size) != 0 || input_shape[kNHWC_C] == 0) { + return NNACL_PARAM_INVALID; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input_shape_size; + output_shape[kNHWC_N] = input_shape[kNHWC_N]; + output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_size; + output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_size; + output_shape[kNHWC_C] = input_shape[kNHWC_C] / (block_size * block_size); + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(DepthToSpace, PrimType_DepthToSpace, DepthToSpaceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.h new file mode 100644 index 00000000..6b67618a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depth_to_space_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H +#define MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/depth_to_space_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.c new file mode 100644 index 00000000..79b88730 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/depthwise_conv2d_infer.h" +#include "nnacl_c/tensor_c_utils.h" + +int DepthwiseConv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + ConvParameter *param = (ConvParameter *)parameter; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_h = input->shape_[1]; + int input_w = input->shape_[2]; + int input_channel = input->shape_[3]; + int output_w = 0, output_h = 0; + param->input_channel_ = input_channel; + + if (param->stride_h_ == 0 || param->stride_w_ == 0) { + return NNACL_PARAM_INVALID; + } + param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : NNACLGetHeight(inputs[kWeightIndex]); + param->kernel_w_ = param->kernel_w_ != -1 ? 
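/* as in deconv2d_infer.c, kernel dims fall back to the weight tensor when unset (-1); under Pad_same the output is ceil(in / stride) and the padding deficit is split evenly, e.g. input_h = 5, stride_h = 2, kernel_h = 3, dilation_h = 1 -> output_h = 3, pad_h_all = (3 - 1) * 2 + 3 - 5 = 2, so pad_u_ = 1 and pad_d_ = 1. */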
param->kernel_w_ : NNACLGetWidth(inputs[kWeightIndex]); + if (param->pad_mode_ == Pad_same) { + output_h = ceil((float)(input_h) / (float)(param->stride_h_)); + output_w = ceil((float)(input_w) / (float)(param->stride_w_)); + int pad_h_all = ((output_h - 1) * param->stride_h_ + (param->kernel_h_ - 1) * param->dilation_h_ + 1 - input_h); + int pad_w_all = ((output_w - 1) * param->stride_w_ + (param->kernel_w_ - 1) * param->dilation_w_ + 1 - input_w); + if (pad_h_all > 0) { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all > 0) { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + } else { + output_h = ceil(((float)(input_h) + param->pad_u_ + param->pad_d_ - + ((float)(param->kernel_h_) - 1) * (float)(param->dilation_h_)) / + (float)(param->stride_h_)); + output_w = ceil(((float)(input_w) + param->pad_l_ + param->pad_r_ - + ((float)(param->kernel_w_) - 1) * (float)(param->dilation_w_)) / + (float)(param->stride_w_)); + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); + out_shape[1] = output_h; + out_shape[2] = output_w; + if (param->channel_multiplie_ != 1) { + return NNACL_ERR; + } + out_shape[3] = input_channel;  // channel multiplier is fixed to 1, so output channels equal input channels + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.h new file mode 100644 index 00000000..6230491e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/depthwise_conv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H +#define MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DepthwiseConv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif  // MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.c new file mode 100644 index 00000000..c3e43e44 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/detection_post_process_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DetectionPostProcessInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *boxes = inputs[0]; + const TensorC *scores = inputs[1]; + const TensorC *anchors = inputs[2]; + if (boxes->shape_size_ < 2 || scores->shape_size_ < 3 || anchors->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + DetectionPostProcessParameter *param = (DetectionPostProcessParameter *)parameter; + if (scores->shape_[2] < param->num_classes_) { + return NNACL_ERR; + } + if (scores->shape_[2] - param->num_classes_ > 1) { + return NNACL_ERR; + } + if (boxes->shape_[1] != scores->shape_[1]) { + return NNACL_ERR; + } + if (boxes->shape_[1] != anchors->shape_[0]) { + return NNACL_ERR; + } + + TensorC *detected_boxes = outputs[0]; + TensorC *detected_classes = outputs[1]; + TensorC *detected_scores = outputs[2]; + TensorC *num_det = outputs[3]; + + detected_boxes->format_ = boxes->format_; + detected_boxes->data_type_ = kNumberTypeFloat32; + detected_classes->format_ = boxes->format_; + detected_classes->data_type_ = kNumberTypeFloat32; + detected_scores->format_ = boxes->format_; + detected_scores->data_type_ = kNumberTypeFloat32; + num_det->format_ = boxes->format_; + num_det->data_type_ = kNumberTypeFloat32; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const int max_detections = param->max_detections_; + const int max_classes_per_detection = param->max_classes_per_detection_; + const int num_detected_boxes = (int)(max_detections * max_classes_per_detection); + detected_boxes->shape_size_ = 3; + detected_boxes->shape_[0] = 1; + detected_boxes->shape_[1] = num_detected_boxes; + detected_boxes->shape_[2] = 4; + detected_classes->shape_size_ = 2; + detected_classes->shape_[0] = 1; + detected_classes->shape_[1] = num_detected_boxes; + detected_scores->shape_size_ = 2; + detected_scores->shape_[0] = 1; + detected_scores->shape_[1] = num_detected_boxes; + num_det->shape_size_ = 1; + num_det->shape_[0] = 1; + + return NNACL_OK; +} + +REG_INFER(DetectionPostProcess, PrimType_DetectionPostProcess, DetectionPostProcessInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.h new file mode 100644 index 00000000..4c40cbe7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/detection_post_process_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H +#define MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/detection_post_process_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DetectionPostProcessInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.c new file mode 100644 index 00000000..dc07820d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/dropout_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DropoutGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(DropoutGrad, PrimType_DropoutGrad, DropoutGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.h new file mode 100644 index 00000000..f3ef8751 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H +#define MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DropoutGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.c new file mode 100644 index 00000000..96e4d263 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/dropout_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int DropoutInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + if (outputs_size > 1) { + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output1, input); + SetShapeTensor(output1, input); + } + return NNACL_OK; +} + +REG_INFER(Dropout, PrimType_Dropout, DropoutInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.h new file mode 100644 index 00000000..73dae73e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dropout_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
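+ */ +/* Documentation note (derived from dropout_infer.c above): the data output and, when present, the optional second (mask) output both mirror the input's shape, type and format. */ +/**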
+ */ +#ifndef MINDSPORE_NNACL_DROPOUT_INFER_H +#define MINDSPORE_NNACL_DROPOUT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DropoutInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DROPOUT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.c new file mode 100644 index 00000000..66022efc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/dynamic_quant_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/dynamic_quant_parameter.h" + +int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + DynamicQuantParameter *param = (DynamicQuantParameter *)parameter; + output->data_type_ = param->dst_type_; + NNACL_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR); + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(DynamicQuant, PrimType_DynamicQuant, DynamicQuantInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.h new file mode 100644 index 00000000..5ede6f2a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/dynamic_quant_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
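+ */ +/* Documentation note (derived from dynamic_quant_infer.c above): the shape passes through unchanged and only data_type_ switches to param->dst_type_, e.g. a float32 {2, 3} input infers an int8 {2, 3} output when dst_type_ is kNumberTypeInt8. */ +/**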
+ */ +#ifndef MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H +#define MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.c new file mode 100644 index 00000000..e3769437 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.c @@ -0,0 +1,77 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/embedding_lookup_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 2 || outputs_size != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *params_ = inputs[0]; + const TensorC *ids = inputs[inputs_size - 1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, params_); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (params_->shape_size_ > MAX_SHAPE_SIZE || ids->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int embedding_shape[MAX_SHAPE_SIZE] = {0}; + size_t embedding_shape_size = 0; + ShapeSet(embedding_shape, &embedding_shape_size, params_->shape_, params_->shape_size_); + int erase_ret = ShapeErase(embedding_shape, &embedding_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, ids->shape_, ids->shape_size_); + for (size_t i = 0; i < embedding_shape_size; ++i) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, embedding_shape[i]); + } + for (size_t i = 1; i < inputs_size - 1; ++i) { + if (inputs[i]->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int embedding_shape_t[MAX_SHAPE_SIZE] = {0}; + size_t embedding_shape_t_size = 0; + ShapeSet(embedding_shape_t, &embedding_shape_t_size, inputs[i]->shape_, inputs[i]->shape_size_); + erase_ret = ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + bool t_equal = ShapeEqual(embedding_shape_t, embedding_shape_t_size, embedding_shape, embedding_shape_size); + if (!t_equal) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(EmbeddingLookup, 
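/* output shape = ids shape with the params shape (minus dim 0) appended, e.g. params {1000, 64} and ids {4, 5} -> output {4, 5, 64}. */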
PrimType_EmbeddingLookupFusion, EmbeddingLookupInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.h new file mode 100644 index 00000000..91715e31 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/embedding_lookup_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H +#define MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.c new file mode 100644 index 00000000..54d72144 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.c @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "nnacl_c/infer/encoder_layer_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C9NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(EncoderLayer, PrimType_Inner_EncoderLayer, EncoderLayerInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.h new file mode 100644 index 00000000..1c156b35 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/encoder_layer_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.c new file mode 100644 index 00000000..cb316584 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/expand_dims_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size < C2NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + + if (inputs[1]->data_ == NULL) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (NNACLGetElementNum(inputs[1]) < 1) { + return NNACL_ERR; + } + int dim = ((int32_t *)(inputs[1]->data_))[0]; + if (dim < 0) { + dim += (int)(input->shape_size_) + 1; + } + if (dim > (int)(input->shape_size_)) { + return NNACL_INPUT_TENSOR_ERROR; + } + + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + int ret = ShapeInsert(output->shape_, &(output->shape_size_), dim, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + return NNACL_OK; +} + +REG_INFER(ExpandDims, PrimType_ExpandDims, ExpandDimsInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.h new file mode 100644 index 00000000..db53049d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/expand_dims_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_EXPAND_DIMS_INFER_H +#define MINDSPORE_NNACL_EXPAND_DIMS_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_EXPAND_DIMS_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.c new file mode 100644 index 00000000..c2865d38 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/fft_imag_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FftImagInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return FftInferShape(inputs, inputs_size, outputs, outputs_size, parameter); +} + +REG_INFER(FftImag, PrimType_FftImag, FftImagInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.h new file mode 100644 index 00000000..44f5b6f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_imag_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FFT_IMAG_INFER_H +#define MINDSPORE_NNACL_FFT_IMAG_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FftImagInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FFT_IMAG_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.c new file mode 100644 index 00000000..a1c3ccc3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/fft_real_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FftRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return FftInferShape(inputs, inputs_size, outputs, outputs_size, parameter); +} + +REG_INFER(FftReal, PrimType_FftReal, FftRealInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.h new file mode 100644 index 00000000..0e233c68 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fft_real_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FFT_REAL_INFER_H +#define MINDSPORE_NNACL_FFT_REAL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FftRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FFT_REAL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.c new file mode 100644 index 00000000..6b47d2a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.c @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/fill_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const TensorC *dst_shape_tensor = inputs[1]; + if (dst_shape_tensor->data_type_ != kNumberTypeInt && dst_shape_tensor->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + const int32_t *dst_shape = (int32_t *)(dst_shape_tensor->data_); + int num_dims = 1; + if (dst_shape_tensor->shape_size_ != DIMENSION_1D) { + return NNACL_ERR; + } + for (size_t i = 0; i < dst_shape_tensor->shape_size_; ++i) { + if (INT_MUL_OVERFLOW(num_dims, dst_shape_tensor->shape_[i])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + NNACL_CHECK_FALSE(dst_shape_tensor->shape_[i] < 0, NNACL_ERR); + num_dims *= dst_shape_tensor->shape_[i]; + } + if (num_dims != 0 && dst_shape == NULL) { + return NNACL_INFER_INVALID; + } + if (num_dims > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < num_dims; i++) { + ShapePush(output_shape, &output_shape_size, dst_shape[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Fill, PrimType_Fill, FillInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.h new file mode 100644 index 00000000..cfe46b02 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fill_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FILL_INFER_H +#define MINDSPORE_NNACL_FILL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FILL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.c new file mode 100644 index 00000000..b2816757 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/fillv2_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FillV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const TensorC *dst_shape_tensor = inputs[0]; + const int32_t *dst_shape = (int32_t *)(dst_shape_tensor->data_); + int num_dims = 1; + if (dst_shape_tensor->shape_size_ != DIMENSION_1D) { + return NNACL_ERR; + } + for (size_t i = 0; i < dst_shape_tensor->shape_size_; ++i) { + if (INT_MUL_OVERFLOW(num_dims, dst_shape_tensor->shape_[i])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + NNACL_CHECK_FALSE(dst_shape_tensor->shape_[i] < 0, NNACL_ERR); + num_dims *= dst_shape_tensor->shape_[i]; + } + if (num_dims != 0 && dst_shape == NULL) { + return NNACL_INFER_INVALID; + } + if (num_dims > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < num_dims; i++) { + ShapePush(output_shape, &output_shape_size, dst_shape[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(FillV2, PrimType_FillV2, FillV2InferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.h new file mode 100644 index 00000000..00d45bd5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fillv2_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FILLV2_INFER_H +#define MINDSPORE_NNACL_FILLV2_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FillV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FILLV2_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.c new file mode 100644 index 00000000..cac74318 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/flatten_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int output_shape_size = inputs[1]->shape_[0]; + if (inputs[1]->data_ == NULL || output_shape_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + SetShapeArray(output, (int *)(inputs[1]->data_), (size_t)output_shape_size); + return NNACL_OK; +} + +REG_INFER(FlattenGrad, PrimType_FlattenGrad, FlattenGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.h new file mode 100644 index 00000000..7fa843c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H +#define MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.c new file mode 100644 index 00000000..b154952a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/flatten_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/flatten_parameter.h" + +int FlattenInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ <= 0 || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + FlattenParameter *param = (FlattenParameter *)parameter; + int axis = param->axis_; + // The value for axis must be in the range [-r, r], where r is + // the rank of the input tensor. A negative value means counting + // dimensions from the back. + axis = axis < 0 ? (int)input_shape_size + axis : axis; + if (axis < 0 || axis >= (int)input_shape_size) { + return NNACL_ERR; + } + int output_shape[2]; + output_shape[0] = axis == 0 ? 1 : input_shape[0]; + for (size_t i = 1; i < (size_t)axis; i++) { + output_shape[0] *= input_shape[i]; + } + output_shape[1] = input_shape[axis]; + for (size_t i = (size_t)axis + 1; i < input_shape_size; i++) { + output_shape[1] *= input_shape[i]; + } + SetShapeArray(output, output_shape, 2); + return NNACL_OK; +} + +REG_INFER(Flatten, PrimType_Flatten, FlattenInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.h new file mode 100644 index 00000000..fc7671dc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/flatten_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FLATTEN_INFER_H +#define MINDSPORE_NNACL_FLATTEN_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FlattenInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FLATTEN_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.c new file mode 100644 index 00000000..6c720a2b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.c @@ -0,0 +1,67 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/format_transpose_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/format_transpose_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +int FormatTransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + FormatTransposeParameter *param = (FormatTransposeParameter *)parameter; + output->format_ = (int)(param->dst_format_); + output->data_type_ = input->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != DIMENSION_4D) { + SetShapeArray(output, input->shape_, input->shape_size_); + return NNACL_OK; + } + + int input_b = NNACLGetBatch(input); + int input_h = NNACLGetHeight(input); + int input_w = NNACLGetWidth(input); + int input_c = NNACLGetChannel(input); + + // set output shape + int out_shape[MAX_SHAPE_SIZE] = {0}; + out_shape[DIMENSION_0D] = input_b; + if (param->dst_format_ == Format_NCHW || param->dst_format_ == Format_NC4HW4 || param->dst_format_ == Format_NC8HW8) { + out_shape[DIMENSION_1D] = input_c; + out_shape[DIMENSION_2D] = input_h; + out_shape[DIMENSION_3D] = input_w; + } else if (param->dst_format_ == Format_NHWC) { + out_shape[DIMENSION_1D] = input_h; + out_shape[DIMENSION_2D] = input_w; + out_shape[DIMENSION_3D] = input_c; + } else { + return NNACL_ERR; + } + + SetShapeArray(output, out_shape, input->shape_size_); + return NNACL_OK; +} + +REG_INFER(FormatTranspose, PrimType_FormatTranspose, FormatTransposeInferShape) diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.h new file mode 100644 index 00000000..b8bb644f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/format_transpose_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H +#define MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FormatTransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.c new file mode 100644 index 00000000..a6d354ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.c @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/fse_decoder_infer.h" +#include "nnacl_c/infer/infer_register.h" + +static const size_t kInputSize = 7; + +int FseDecoderInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, kInputSize, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *cen_input = inputs[4]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, cen_input); + + return NNACL_OK; +} + +REG_INFER(FseDecode, PrimType_Inner_FseDecode, FseDecoderInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.h new file mode 100644 index 00000000..2f93ba44 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fse_decoder_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FseDecoderInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.c new file mode 100644 index 00000000..e9ffefde --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.c @@ -0,0 +1,92 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/full_connection_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FullConnectionInferPreJudge(const MatMulParameter *param, size_t inputs_size, const TensorC *input0) { + if ((param->has_bias_ && inputs_size != 3) || (!param->has_bias_ && inputs_size != 2)) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->use_axis_ && (param->axis_ < 1 || param->axis_ > (int)(input0->shape_size_))) { + return NNACL_ERR; + } + return NNACL_OK; +} + +int FullConnectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + const TensorC *input1 = inputs[1]; + TensorC *output = outputs[0]; + MatMulParameter *param = (MatMulParameter *)parameter; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int pre_judge = FullConnectionInferPreJudge(param, inputs_size, input0); + if (pre_judge != NNACL_OK) { + return pre_judge; + } + int new_k = 1; + if (param->use_axis_) { + for (size_t i = (size_t)(param->axis_); i < input0->shape_size_; ++i) { + new_k *= input0->shape_[i]; + } + if (new_k != input1->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } else { + new_k = input1->shape_[1]; + } + if (param->has_bias_) { + if (inputs[2]->shape_[0] != input1->shape_[0]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + if (inputs[0]->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); + if (param->use_axis_) { + out_shape_size = (size_t)(param->axis_) + 1; + out_shape[param->axis_] = input1->shape_[0]; + } else { + int total = 1; + for (size_t i = 0; i < input0->shape_size_; ++i) { + total *= input0->shape_[i]; + } + out_shape_size = 2; + if (new_k == 0) { + return NNACL_ERR; + } + int batch_size = total / new_k; + out_shape[0] = batch_size; + out_shape[1] = input1->shape_[0]; + } + SetShapeArray(output, out_shape, out_shape_size); + + return NNACL_OK; +} + +REG_INFER(FullConnection, PrimType_FullConnection, FullConnectionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.h new file mode 100644 index 00000000..18cb1c7f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/full_connection_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FULL_CONNECTION_INFER_H +#define MINDSPORE_NNACL_FULL_CONNECTION_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FullConnectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FULL_CONNECTION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.c new file mode 100644 index 00000000..0a8b247d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/fused_batchnorm_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + for (size_t i = 0; i < inputs_size; i++) { + if (outputs_size <= i) { + break; + } + SetShapeTensor(outputs[i], inputs[i]); + SetDataTypeFormat(outputs[i], inputs[i]); + } + if (outputs_size > 5) { + SetDataTypeFormat(outputs[5], inputs[0]); + outputs[5]->shape_size_ = 1; + outputs[5]->shape_[0] = 1; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(FusedBatchNorm, PrimType_FusedBatchNorm, FusedBatchNormInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.h new file mode 100644 index 00000000..9279dba1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/fused_batchnorm_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H +#define MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.c new file mode 100644 index 00000000..2cd90d21 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/gather_d_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int GatherDInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const int input_size_limit = 3; + const int output_size_limit = 1; + if (inputs_size != input_size_limit || outputs_size != output_size_limit) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *index = inputs[2]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(output, index); + return NNACL_OK; +} + +REG_INFER(GatherD, PrimType_GatherD, GatherDInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.h new file mode 100644 index 00000000..0b600099 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_d_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_NNACL_GATHER_D_INFER_H +#define MINDSPORE_NNACL_GATHER_D_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gather_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GatherDInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GATHER_D_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.c new file mode 100644 index 00000000..653e1839 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/gather_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const size_t kMinimumGradInputsNum = 3; + if (inputs_size < kMinimumGradInputsNum || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *indices = inputs[1]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + if ((input->data_type_ == kNumberTypeInt8 || input->data_type_ == kNumberTypeInt16) && + (parameter->quant_type_ == Quant_QuantWeight || parameter->quant_type_ == Quant_QuantDynamic)) { + output->data_type_ = kNumberTypeFloat32; + } + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (inputs[2]->data_ == NULL) { + return NNACL_NULL_PTR; + } + if (NNACLGetElementNum(inputs[2]) < 1) { + return NNACL_ERR; + } + int axis = *((int *)inputs[2]->data_); + if (axis < 0) { + axis += input->shape_size_; + } + int indices_shape[MAX_SHAPE_SIZE]; + size_t indices_shape_size = 0; + ShapeSet(indices_shape, &indices_shape_size, indices->shape_, indices->shape_size_); + size_t indices_rank = indices_shape_size; + int in_shape[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); + if ((int)(in_shape_size) < axis + 1) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size); + int erase_ret = ShapeErase(out_shape, &out_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + for (int i = (int)(indices_rank - 1); i >= 0; --i) { + ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]); + if (ret != NNACL_OK) { + return NNACL_ERR; + 
} + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Gather, PrimType_Gather, GatherInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.h new file mode 100644 index 00000000..25cafb2f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GATHER_INFER_H +#define MINDSPORE_NNACL_GATHER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gather_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GATHER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c new file mode 100644 index 00000000..27512661 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/gather_nd_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *indices = inputs[1]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int in_rank = (int)(input->shape_size_); + int indices_rank = (int)(indices->shape_size_); + for (int i = 0; i < indices_rank; i++) { + NNACL_CHECK_FALSE(indices->shape_[i] == 0, NNACL_ERR); + } + if (indices->shape_[indices_rank - 1] > in_rank) { + return NNACL_ERR; + } + int i = 0; + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + for (i = 0; i < indices_rank - 1; ++i) { + ShapePush(out_shape, &out_shape_size, indices->shape_[i]); + } + for (i = indices->shape_[indices_rank - 1]; i < in_rank; ++i) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(GatherNd, PrimType_GatherNd, GatherNdInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.h new file mode 100644 index 00000000..f2a102fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gather_nd_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GATHER_ND_INFER_H +#define MINDSPORE_NNACL_GATHER_ND_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/gatherNd_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GATHER_ND_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.c new file mode 100644 index 00000000..9ff6ecb4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/glu_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/glu_parameter.h" + +int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + GluParameter *param = (GluParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ >= (int)input->shape_size_ || (param->axis_ < 0 && ((int)input->shape_size_ + param->axis_) < 0)) { + return NNACL_ERR; + } + int axis = param->axis_ >= 0 ? param->axis_ : (int)input->shape_size_ + param->axis_; + if (axis < 0 || axis >= MAX_SHAPE_SIZE) { + return NNACL_BUFFER_OVERFLOW; + } + output->shape_[axis] /= 2; + return NNACL_OK; +} + +REG_INFER(GLU, PrimType_GLU, GluInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.h new file mode 100644 index 00000000..a32b3487 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/glu_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GLU_INFER_H +#define MINDSPORE_NNACL_GLU_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GLU_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.c new file mode 100644 index 00000000..25366bb6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/grid_sampler_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/grid_sampler_parameter.h" + +int GridSamplerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != inputs[1]->shape_size_) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (input->shape_size_ < DIMENSION_4D) { + return NNACL_INPUT_TENSOR_ERROR; + } + SetShapeTensor(output, input); + for (size_t i = DIMENSION_2D; i < input->shape_size_; ++i) { + output->shape_[i] = inputs[1]->shape_[i - 1]; + } + return NNACL_OK; +} + +REG_INFER(GridSampler, PrimType_Inner_GridSampler, GridSamplerInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.h new file mode 100644 index 00000000..6110b83a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/grid_sampler_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GRID_SAMPLER_INFER_H +#define MINDSPORE_NNACL_GRID_SAMPLER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GridSamplerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GRID_SAMPLER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.c new file mode 100644 index 00000000..d7fc86de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/group_conv2d_grad_input_infer.h" + +int GroupConv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size < 2 || outputs_size != 1) { + return NNACL_ERR; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + SetDataTypeFormat(out, in0); + + size_t shape_size = in0->shape_size_; + if (shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int shape_[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < shape_size; i++) { + shape_[i] = in0->shape_[i]; + } + SetShapeArray(out, shape_, shape_size); + + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.h new file mode 100644 index 00000000..e807f484 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_conv2d_grad_input_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H +#define MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GroupConv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.c new file mode 100644 index 00000000..69615493 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/group_norm_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(GroupNorm, PrimType_GroupNormFusion, GroupNormInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.h new file mode 100644 index 00000000..c9f2e245 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/group_norm_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.c new file mode 100644 index 00000000..310bad75 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.c @@ -0,0 +1,92 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
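+ *
+ * GruInferShape derives the GRU output shape [seq_len, num_directions, batch,
+ * hidden_size] and the hidden-state shape from the inputs, where hidden_size
+ * is weight_gate->shape_[1] / 3. A minimal usage sketch (names from this
+ * file; tensor setup elided, GruParameter layout assumed):
+ *
+ *   GruParameter param = {0};
+ *   param.bidirectional_ = false;  // true inserts a size-2 direction axis
+ *   int ret = GruInferShape(inputs, 5, outputs, 2, (OpParameter *)&param);
+ *   // NNACL_OK on success; NNACL_INFER_INVALID defers shapes to runtime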
+ */ + +#include "nnacl_c/infer/gru_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 5, 6, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *weight_gate = inputs[1]; + const TensorC *weight_recurrence = inputs[2]; + const TensorC *bias = inputs[3]; + TensorC *output = outputs[0]; + for (int i = 0; i < 2; i++) { + SetDataTypeFormat(outputs[i], input); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *in_shape = input->shape_; // seq_len, batch, input_size + const int *w_gate_shape = weight_gate->shape_; // num_direction, hidden_size * 3, input_size + const int *w_recu_shape = weight_recurrence->shape_; // num_direction, hidden_size * 3, hidden_size + const int *bias_shape = bias->shape_; // num_direction, hidden_size * 6 + if (input->shape_size_ != 3 || weight_gate->shape_size_ != 3 || weight_recurrence->shape_size_ != 3) { + return NNACL_ERR; + } + if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) { + return NNACL_ERR; + } + if (inputs_size == 6) { + const int *seq_len_shape = inputs[5]->shape_; + if (seq_len_shape[0] > 1) { + return NNACL_ERR; + } + if (inputs[5]->shape_size_ != 1 && seq_len_shape[0] != in_shape[1]) { + return NNACL_ERR; + } + } + + int hidden_size = w_gate_shape[1] / 3; + // set output + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, input->shape_size_); + out_shape[2] = hidden_size; + + GruParameter *param = (GruParameter *)parameter; + if (param->bidirectional_) { + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } else { + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } + SetShapeArray(output, out_shape, out_shape_size); + // set hidden state + int state_shape[MAX_SHAPE_SIZE]; + size_t state_shape_size = 0; + ShapeSet(state_shape, &state_shape_size, in_shape, input->shape_size_); + state_shape[0] = param->bidirectional_ ? 2 : 1; + state_shape[2] = hidden_size; + SetShapeArray(outputs[1], state_shape, state_shape_size); + return NNACL_OK; +} + +REG_INFER(GRU, PrimType_GRU, GruInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.h new file mode 100644 index 00000000..fc57baf2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/gru_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
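+ *
+ * Declares the PrimType_GRU shape-inference entry point; the implementation
+ * and its REG_INFER registration live in gru_infer.c.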
+ */ +#ifndef MINDSPORE_NNACL_GRU_INFER_H +#define MINDSPORE_NNACL_GRU_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GRU_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer.h new file mode 100644 index 00000000..c22403e1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_INFER_H_ +#define MINDSPORE_NNACL_INFER_INFER_H_ + +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef int (*InferShape)(const TensorC *const *inputs, size_t input_size, TensorC **outputs, size_t output_size, + OpParameter *parameter); + +InferShape GetInferFunc(int prim_type); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c new file mode 100644 index 00000000..f8ec26dc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
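+ *
+ * This file defines the two dispatch tables: g_infer_func for public
+ * PrimType values and g_inner_op_infer_func for inner ops. MSVC has no
+ * __attribute__((constructor)), so under _MSC_VER the tables are filled
+ * lazily by RegAllInferFunc1..5 on the first GetInferFunc call; on other
+ * compilers each op self-registers at load time via the REG_INFER hook.
+ * Typical lookup, as a sketch (error handling elided):
+ *
+ *   InferShape fn = GetInferFunc(prim_type);
+ *   int ret = (fn != NULL) ? fn(inputs, inputs_size, outputs, outputs_size, param) : NNACL_NULL_PTR;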
+ */ +#include "nnacl_c/infer/infer_register.h" + +#ifdef _MSC_VER +#include "nnacl_c/infer/activation_grad_infer.h" +#include "nnacl_c/infer/adam_infer.h" +#include "nnacl_c/infer/adam_weight_decay_infer.h" +#include "nnacl_c/infer/add_sub_grad_infer.h" +#include "nnacl_c/infer/addn_infer.h" +#include "nnacl_c/infer/affine_infer.h" +#include "nnacl_c/infer/all_gather_infer.h" +#include "nnacl_c/infer/apply_momentum_infer.h" +#include "nnacl_c/infer/argmin_max_infer.h" +#include "nnacl_c/infer/arithmetic_compare_infer.h" +#include "nnacl_c/infer/arithmetic_grad_infer.h" +#include "nnacl_c/infer/arithmetic_infer.h" +#include "nnacl_c/infer/assert_op_infer.h" +#include "nnacl_c/infer/assign_add_infer.h" +#include "nnacl_c/infer/assign_infer.h" +#include "nnacl_c/infer/attention_infer.h" +#include "nnacl_c/infer/encoder_layer_infer.h" +#include "nnacl_c/infer/audio_spectrogram_infer.h" +#include "nnacl_c/infer/batch_to_space_infer.h" +#include "nnacl_c/infer/bias_grad_infer.h" +#include "nnacl_c/infer/binary_cross_entropy_infer.h" +#include "nnacl_c/infer/bn_grad_infer.h" +#include "nnacl_c/infer/broadcast_to_infer.h" +#include "nnacl_c/infer/cast_infer.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/infer/concat_infer.h" +#include "nnacl_c/infer/constant_of_shape_infer.h" +#include "nnacl_c/infer/decoder_layer_infer.h" + +#ifdef MSLITE_ENABLE_CONTROLFLOW +#include "nnacl_c/infer/control/tensor_array_infer.h" +#include "nnacl_c/infer/control/tensor_array_read_infer.h" +#include "nnacl_c/infer/control/tensor_array_write_infer.h" +#include "nnacl_c/infer/control/tensorlist_fromtensor_infer.h" +#include "nnacl_c/infer/control/tensorlist_getitem_infer.h" +#include "nnacl_c/infer/control/tensorlist_reserve_infer.h" +#include "nnacl_c/infer/control/tensorlist_setitem_infer.h" +#include "nnacl_c/infer/control/tensorlist_stack_infer.h" +#endif +#include "nnacl_c/infer/conv2d_grad_filter_infer.h" +#include "nnacl_c/infer/conv2d_grad_input_infer.h" +#include "nnacl_c/infer/conv2d_infer.h" +#include "nnacl_c/infer/crop_and_resize_infer.h" +#include "nnacl_c/infer/crop_infer.h" +#include "nnacl_c/infer/cumsum_infer.h" +#include "nnacl_c/infer/deconv2d_infer.h" +#include "nnacl_c/infer/depth_to_space_infer.h" +#include "nnacl_c/infer/depthwise_conv2d_infer.h" +#include "nnacl_c/infer/detection_post_process_infer.h" +#include "nnacl_c/infer/dropout_grad_infer.h" +#include "nnacl_c/infer/dropout_infer.h" +#include "nnacl_c/infer/dynamic_quant_infer.h" +#include "nnacl_c/infer/embedding_lookup_infer.h" +#include "nnacl_c/infer/expand_dims_infer.h" +#include "nnacl_c/infer/fft_imag_infer.h" +#include "nnacl_c/infer/fft_real_infer.h" +#include "nnacl_c/infer/fill_infer.h" +#include "nnacl_c/infer/fillv2_infer.h" +#include "nnacl_c/infer/flatten_grad_infer.h" +#include "nnacl_c/infer/flatten_infer.h" +#include "nnacl_c/infer/full_connection_infer.h" +#include "nnacl_c/infer/fused_batchnorm_infer.h" +#include "nnacl_c/infer/gather_infer.h" +#include "nnacl_c/infer/gather_nd_infer.h" +#include "nnacl_c/infer/glu_infer.h" +#include "nnacl_c/infer/group_conv2d_grad_input_infer.h" +#include "nnacl_c/infer/gru_infer.h" +#include "nnacl_c/infer/instance_norm_infer.h" +#include "nnacl_c/infer/invert_permutation_infer.h" +#include "nnacl_c/infer/layer_norm_grad_infer.h" +#include "nnacl_c/infer/layer_norm_infer.h" +#include "nnacl_c/infer/lin_space_infer.h" +#include "nnacl_c/infer/log_softmax_infer.h" +#include "nnacl_c/infer/lstm_grad_data_infer.h" +#include "nnacl_c/infer/lstm_grad_infer.h" +#include 
"nnacl_c/infer/lstm_grad_weight_infer.h" +#include "nnacl_c/infer/lstm_infer.h" +#include "nnacl_c/infer/matmul_infer.h" +#include "nnacl_c/infer/max_min_grad_infer.h" +#include "nnacl_c/infer/mfcc_infer.h" +#include "nnacl_c/infer/nllloss_grad_infer.h" +#include "nnacl_c/infer/nllloss_infer.h" +#include "nnacl_c/infer/non_max_suppression_infer.h" +#include "nnacl_c/infer/one_hot_infer.h" +#include "nnacl_c/infer/pad_infer.h" +#include "nnacl_c/infer/pooling_grad_infer.h" +#include "nnacl_c/infer/pooling_infer.h" +#include "nnacl_c/infer/power_infer.h" +#include "nnacl_c/infer/prior_box_infer.h" +#include "nnacl_c/infer/quant_dtype_cast_infer.h" +#include "nnacl_c/infer/ragged_range_infer.h" +#include "nnacl_c/infer/random_normal_infer.h" +#include "nnacl_c/infer/random_standard_normal_infer.h" +#include "nnacl_c/infer/range_infer.h" +#include "nnacl_c/infer/rank_infer.h" +#include "nnacl_c/infer/reduce_infer.h" +#include "nnacl_c/infer/reduce_scatter_infer.h" +#include "nnacl_c/infer/reshape_infer.h" +#include "nnacl_c/infer/resize_grad_infer.h" +#include "nnacl_c/infer/resize_infer.h" +#include "nnacl_c/infer/rfft_infer.h" +#include "nnacl_c/infer/roi_pooling_infer.h" +#include "nnacl_c/infer/scatter_nd_infer.h" +#include "nnacl_c/infer/scatter_nd_update_infer.h" +#include "nnacl_c/infer/select_infer.h" +#include "nnacl_c/infer/sgd_infer.h" +#include "nnacl_c/infer/invalid_infer.h" +#ifndef RUNTIME_PASS_CLIP +#include "nnacl_c/infer/shape_fusion_infer.h" +#endif +#include "nnacl_c/infer/shape_infer.h" +#include "nnacl_c/infer/size_infer.h" +#include "nnacl_c/infer/slice_infer.h" +#include "nnacl_c/infer/softmax_cross_entropy_infer.h" +#include "nnacl_c/infer/softmax_infer.h" +#include "nnacl_c/infer/space_to_batch_infer.h" +#include "nnacl_c/infer/space_to_batch_nd_infer.h" +#include "nnacl_c/infer/space_to_depth_infer.h" +#include "nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h" +#include "nnacl_c/infer/sparse_to_dense_infer.h" +#include "nnacl_c/infer/splice_infer.h" +#include "nnacl_c/infer/split_infer.h" +#include "nnacl_c/infer/split_with_over_lap_infer.h" +#include "nnacl_c/infer/squeeze_infer.h" +#include "nnacl_c/infer/stack_infer.h" +#include "nnacl_c/infer/strided_slice_grad_infer.h" +#include "nnacl_c/infer/strided_slice_infer.h" +#ifdef MSLITE_ENABLE_STRING_KERNEL +#include "nnacl_c/infer/string/custom_extract_features_infer.h" +#include "nnacl_c/infer/string/custom_normalize_infer.h" +#include "nnacl_c/infer/string/custom_predict_infer.h" +#include "nnacl_c/infer/string/hashtable_lookup_infer.h" +#include "nnacl_c/infer/string/lsh_projection_infer.h" +#include "nnacl_c/infer/string/skip_gram_infer.h" +#endif +#include "nnacl_c/infer/tile_infer.h" +#include "nnacl_c/infer/topk_infer.h" +#include "nnacl_c/infer/transpose_infer.h" +#include "nnacl_c/infer/uniform_real_infer.h" +#include "nnacl_c/infer/unique_infer.h" +#include "nnacl_c/infer/unsorted_segment_sum_infer.h" +#include "nnacl_c/infer/unsqueeze_infer.h" +#include "nnacl_c/infer/unstack_infer.h" +#include "nnacl_c/infer/where_infer.h" +#include "nnacl_c/infer/isfinite_infer.h" +#include "nnacl_c/infer/fse_decoder_infer.h" +#include "nnacl_c/infer/custom_gru_infer.h" + +InferShape g_infer_func[PrimType_MAX] = {0}; +InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; +void RegAllInferFunc1() { + g_infer_func[PrimType_NONE] = NULL; + g_infer_func[PrimType_Abs] = CommonInferShape; + g_infer_func[PrimType_AbsGrad] = CommonGradInferShape; + g_infer_func[PrimType_Activation] = 
CommonInferShape; + g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape; + g_infer_func[PrimType_Adam] = AdamInferShape; + g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape; + g_infer_func[PrimType_AdderFusion] = Conv2dInferShape; + g_infer_func[PrimType_AddFusion] = ArithmeticInferShape; + g_infer_func[PrimType_AddGrad] = AddSubGradInferShape; + g_infer_func[PrimType_AddN] = AddnInferShape; + g_infer_func[PrimType_Affine] = AffineInferShape; + g_infer_func[PrimType_All] = NULL; + g_infer_func[PrimType_AllGather] = AllGatherInferShape; + g_infer_func[PrimType_ApplyMomentum] = ApplyMomentumInferShape; + g_infer_func[PrimType_ArgMaxFusion] = ArgMinMaxInferShape; + g_infer_func[PrimType_ArgMinFusion] = ArgMinMaxInferShape; + g_infer_func[PrimType_Assert] = AssertOpInferShape; + g_infer_func[PrimType_Assign] = AssignInferShape; + g_infer_func[PrimType_AssignAdd] = AssignAddInferShape; + g_infer_func[PrimType_Attention] = AttentionInferShape; + g_infer_func[PrimType_AudioSpectrogram] = AudioSpectrogramInferShape; + g_infer_func[PrimType_AvgPoolFusion] = PoolingInferShape; + g_infer_func[PrimType_AvgPoolGrad] = PoolingGradInferShape; + g_infer_func[PrimType_BatchNorm] = CommonInferShape; + g_infer_func[PrimType_BatchNormGrad] = BnGradInferShape; + g_infer_func[PrimType_BatchToSpace] = BatchToSpaceInferShape; + g_infer_func[PrimType_BatchToSpaceND] = NULL; + g_infer_func[PrimType_BiasAdd] = ArithmeticInferShape; + g_infer_func[PrimType_BiasAddGrad] = BiasGradInferShape; + g_infer_func[PrimType_BinaryCrossEntropy] = BinaryCrossEntropyInferShape; + g_infer_func[PrimType_BinaryCrossEntropyGrad] = CommonInferShape; + g_infer_func[PrimType_BroadcastTo] = BroadcastToInferShape; + g_infer_func[PrimType_Call] = InvalidInferShape; + g_infer_func[PrimType_Cast] = CastInferShape; + g_infer_func[PrimType_Ceil] = CommonInferShape; + g_infer_func[PrimType_Clip] = CommonInferShape; + g_infer_func[PrimType_Concat] = ConcatInferShape; + g_infer_func[PrimType_ConstantOfShape] = ConstantOfShapeInferShape; + g_infer_func[PrimType_Conv2DBackpropFilterFusion] = Conv2dGradFilterInferShape; + g_infer_func[PrimType_Conv2DBackpropInputFusion] = Conv2dGradInputInferShape; + g_infer_func[PrimType_Conv2DFusion] = Conv2dInferShape; + g_infer_func[PrimType_Conv2dTransposeFusion] = Deconv2dInferShape; + g_infer_func[PrimType_Cos] = CommonInferShape; + g_infer_func[PrimType_Crop] = CropInferShape; + g_infer_func[PrimType_CropAndResize] = CropAndResizeInferShape; + g_infer_func[PrimType_CumSum] = CumsumInferShape; + g_infer_func[PrimType_Custom] = NULL; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_CustomExtractFeatures] = CustomExtractFeaturesInferShape; +#endif +} + +void RegAllInferFunc2() { +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_CustomNormalize] = CustomNormalizeInferShape; + g_infer_func[PrimType_CustomPredict] = CustomPredictInferShape; +#endif + g_infer_func[PrimType_DeConv2DGradFilter] = NULL; + g_infer_func[PrimType_Depend] = CommonInferShape; + g_infer_func[PrimType_DepthToSpace] = DepthToSpaceInferShape; + g_infer_func[PrimType_DetectionPostProcess] = DetectionPostProcessInferShape; + g_infer_func[PrimType_DivFusion] = ArithmeticInferShape; + g_infer_func[PrimType_DivGrad] = ArithmeticGradInferShape; + g_infer_func[PrimType_Dropout] = DropoutInferShape; + g_infer_func[PrimType_DropoutGrad] = DropoutGradInferShape; + g_infer_func[PrimType_DynamicQuant] = DynamicQuantInferShape; + g_infer_func[PrimType_Eltwise] = ArithmeticInferShape; + 
g_infer_func[PrimType_Elu] = CommonInferShape; + g_infer_func[PrimType_EmbeddingLookupFusion] = EmbeddingLookupInferShape; + g_infer_func[PrimType_Equal] = ArithmeticCompareInferShape; + g_infer_func[PrimType_Erf] = CommonInferShape; + g_infer_func[PrimType_ExpandDims] = ExpandDimsInferShape; + g_infer_func[PrimType_ExpFusion] = CommonInferShape; + g_infer_func[PrimType_FakeQuantWithMinMaxVars] = CommonInferShape; + g_infer_func[PrimType_FakeQuantWithMinMaxVarsPerChannel] = NULL; + g_infer_func[PrimType_FftImag] = FftImagInferShape; + g_infer_func[PrimType_FftReal] = FftRealInferShape; + g_infer_func[PrimType_Fill] = FillInferShape; + g_infer_func[PrimType_FillV2] = FillInferShape; + g_infer_func[PrimType_Flatten] = FlattenInferShape; + g_infer_func[PrimType_FlattenGrad] = FlattenGradInferShape; + g_infer_func[PrimType_Floor] = CommonInferShapeWithOneInput; + g_infer_func[PrimType_FloorDiv] = ArithmeticInferShape; + g_infer_func[PrimType_FloorMod] = ArithmeticInferShape; + g_infer_func[PrimType_FullConnection] = FullConnectionInferShape; + g_infer_func[PrimType_FusedBatchNorm] = FusedBatchNormInferShape; + g_infer_func[PrimType_Gather] = GatherInferShape; + g_infer_func[PrimType_GatherNd] = GatherNdInferShape; + g_infer_func[PrimType_GenOP] = NULL; + g_infer_func[PrimType_GLU] = GluInferShape; + g_infer_func[PrimType_Greater] = ArithmeticCompareInferShape; + g_infer_func[PrimType_GreaterEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_GRU] = GruInferShape; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_HashtableLookup] = HashtableLoopupInferShape; +#endif + g_infer_func[PrimType_InstanceNorm] = InstanceNormInferShape; + g_infer_func[PrimType_InvertPermutation] = InvertPermutationInferShape; + g_infer_func[PrimType_IsFinite] = IsFiniteInferShape; + g_infer_func[PrimType_L2NormalizeFusion] = CommonInferShape; + g_infer_func[PrimType_LayerNormFusion] = LayerNormInferShape; + g_infer_func[PrimType_LayerNormGrad] = LayerNormGradInferShape; + g_infer_func[PrimType_LeakyRelu] = CommonInferShape; + g_infer_func[PrimType_Less] = ArithmeticCompareInferShape; + g_infer_func[PrimType_LessEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_LinSpace] = LinSpaceInferShape; +} + +void RegAllInferFunc3() { + g_infer_func[PrimType_Log] = CommonInferShape; + g_infer_func[PrimType_LogGrad] = CommonGradInferShape; + g_infer_func[PrimType_LogicalAnd] = ArithmeticInferShape; + g_infer_func[PrimType_LogicalNot] = CommonInferShape; + g_infer_func[PrimType_LogicalOr] = ArithmeticInferShape; + g_infer_func[PrimType_LogSoftmax] = LogSoftmaxInferShape; + g_infer_func[PrimType_LpNormalization] = NULL; + g_infer_func[PrimType_LRN] = CommonInferShapeWithNHWC; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_LshProjection] = LshProjectionInferShape; +#endif + g_infer_func[PrimType_LSTM] = LstmInferShape; + g_infer_func[PrimType_LSTMGrad] = LstmGradInferShape; + g_infer_func[PrimType_LSTMGradData] = LstmGradDataInferShape; + g_infer_func[PrimType_LSTMGradWeight] = LstmGradWeightInferShape; + g_infer_func[PrimType_MatMulFusion] = MatmulInferShape; + g_infer_func[PrimType_Maximum] = ArithmeticInferShape; + g_infer_func[PrimType_MaximumGrad] = MaxMinGradInferShape; + g_infer_func[PrimType_MaxPoolFusion] = PoolingInferShape; + g_infer_func[PrimType_MaxPoolGrad] = PoolingGradInferShape; + g_infer_func[PrimType_SwitchLayer] = InvalidInferShape; + g_infer_func[PrimType_Mfcc] = MfccInferShape; + g_infer_func[PrimType_MIN] = NULL; + g_infer_func[PrimType_Minimum] = 
ArithmeticInferShape; + g_infer_func[PrimType_MinimumGrad] = MaxMinGradInferShape; + g_infer_func[PrimType_Mod] = ArithmeticInferShape; + g_infer_func[PrimType_MulFusion] = ArithmeticInferShape; + g_infer_func[PrimType_MulGrad] = ArithmeticGradInferShape; + g_infer_func[PrimType_Neg] = CommonInferShape; + g_infer_func[PrimType_NegGrad] = CommonGradInferShape; + g_infer_func[PrimType_NLLLoss] = NLLLossInferShape; + g_infer_func[PrimType_NLLLossGrad] = NLLLossGradInferShape; + g_infer_func[PrimType_NonMaxSuppression] = NonMaxSuppressionInferShape; + g_infer_func[PrimType_NonZero] = NULL; + g_infer_func[PrimType_NotEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_OneHot] = OneHotInferShape; + g_infer_func[PrimType_OnesLike] = NULL; + g_infer_func[PrimType_PadFusion] = PadInferShape; + g_infer_func[PrimType_PartialFusion] = InvalidInferShape; + g_infer_func[PrimType_PowerGrad] = CommonGradInferShape; + g_infer_func[PrimType_PowFusion] = PowerInferShape; + g_infer_func[PrimType_PReLUFusion] = CommonInferShape; + g_infer_func[PrimType_PriorBox] = PriorBoxInferShape; + g_infer_func[PrimType_QuantDTypeCast] = QuantDtypeCastInferShape; + g_infer_func[PrimType_RaggedRange] = RaggedRangeInferShape; + g_infer_func[PrimType_RandomNormal] = RandomNormalInferShape; + g_infer_func[PrimType_RandomStandardNormal] = RandomStandardNormalInferShape; + g_infer_func[PrimType_Range] = RangeInferShape; + g_infer_func[PrimType_Rank] = RankInferShape; +} + +void RegAllInferFunc4() { + g_infer_func[PrimType_RealDiv] = ArithmeticInferShape; + g_infer_func[PrimType_Reciprocal] = CommonInferShape; + g_infer_func[PrimType_ReduceFusion] = ReduceInferShape; + g_infer_func[PrimType_ReduceScatter] = ReduceScatterInferShape; + g_infer_func[PrimType_Reshape] = ReshapeInferShape; + g_infer_func[PrimType_Resize] = ResizeInferShape; + g_infer_func[PrimType_ResizeGrad] = ResizeGradInferShape; + g_infer_func[PrimType_ReverseSequence] = CommonInferShape; + g_infer_func[PrimType_ReverseV2] = CommonInferShape; + g_infer_func[PrimType_Rfft] = RfftInferShape; + g_infer_func[PrimType_ROIPooling] = ROIPoolingInferShape; + g_infer_func[PrimType_Round] = CommonInferShape; + g_infer_func[PrimType_Rsqrt] = CommonInferShape; + g_infer_func[PrimType_RsqrtGrad] = NULL; + g_infer_func[PrimType_ScaleFusion] = CommonInferShape; + g_infer_func[PrimType_ScatterNd] = ScatterNdInferShape; + g_infer_func[PrimType_ScatterNdUpdate] = ScatterNdUpdateInferShape; + g_infer_func[PrimType_TensorScatterAdd] = ScatterNdUpdateInferShape; + g_infer_func[PrimType_Select] = SelectInferShape; + g_infer_func[PrimType_SGD] = SgdInferShape; + g_infer_func[PrimType_Shape] = ShapeInferShape; + g_infer_func[PrimType_SigmoidCrossEntropyWithLogits] = CommonInferShape; + g_infer_func[PrimType_SigmoidCrossEntropyWithLogitsGrad] = CommonInferShape; + g_infer_func[PrimType_Sin] = CommonInferShape; + g_infer_func[PrimType_Size] = SizeInferShape; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_SkipGram] = SkipGramInferShape; +#endif + g_infer_func[PrimType_SliceFusion] = SliceInferShape; + g_infer_func[PrimType_SmoothL1Loss] = CommonInferShape; + g_infer_func[PrimType_SmoothL1LossGrad] = CommonInferShape; + g_infer_func[PrimType_Softmax] = SoftMaxInferShape; + g_infer_func[PrimType_SoftmaxCrossEntropyWithLogits] = SoftmaxCrossEntropyInferShape; + g_infer_func[PrimType_SpaceToBatch] = SpaceToBatchInferShape; + g_infer_func[PrimType_SpaceToBatchND] = SpaceToBatchNdInferShape; + g_infer_func[PrimType_SpaceToDepth] = SpaceToDepthInferShape; + 
g_infer_func[PrimType_SparseSoftmaxCrossEntropyWithLogits] = SparseSoftmaxCrossEntropyWithLogitsInferShape; + g_infer_func[PrimType_SparseToDense] = SparseToDenseInferShape; + g_infer_func[PrimType_Splice] = SpliceInferShape; + g_infer_func[PrimType_Split] = SplitInferShape; + g_infer_func[PrimType_SplitWithOverlap] = SplitWithOverlapInferShape; + g_infer_func[PrimType_Sqrt] = CommonInferShape; + g_infer_func[PrimType_SqrtGrad] = NULL; + g_infer_func[PrimType_Square] = CommonInferShape; + g_infer_func[PrimType_SquaredDifference] = ArithmeticInferShape; + g_infer_func[PrimType_Squeeze] = SqueezeInferShape; + g_infer_func[PrimType_Stack] = StackInferShape; + g_infer_func[PrimType_StridedSlice] = StridedSliceInferShape; + g_infer_func[PrimType_StridedSliceGrad] = StridedSliceGradInferShape; + g_infer_func[PrimType_SubFusion] = ArithmeticInferShape; + g_infer_func[PrimType_SubGrad] = AddSubGradInferShape; +} + +void RegAllInferFunc5() { + g_infer_func[PrimType_Switch] = InvalidInferShape; +#ifdef MSLITE_ENABLE_CONTROLFLOW + g_infer_func[PrimType_TensorArray] = TensorArrayInferShape; + g_infer_func[PrimType_TensorArrayRead] = TensorArrayReadInferShape; + g_infer_func[PrimType_TensorArrayWrite] = TensorArrayWriteInferShape; + g_infer_func[PrimType_TensorListFromTensor] = TensorListFromTensorInferShape; + g_infer_func[PrimType_TensorListGetItem] = TensorListGetItemInferShape; + g_infer_func[PrimType_TensorListReserve] = TensorListReserveInferShape; + g_infer_func[PrimType_TensorListSetItem] = TensorListSetItemInferShape; + g_infer_func[PrimType_TensorListStack] = TensorListStackInferShape; +#endif + g_infer_func[PrimType_TileFusion] = TileInferShape; + g_infer_func[PrimType_TopKFusion] = TopKInferShape; + g_infer_func[PrimType_Transpose] = TransposeInferShape; + g_infer_func[PrimType_UniformReal] = UniformRealInferShape; + g_infer_func[PrimType_Unique] = UniqueInferShape; + g_infer_func[PrimType_UnsortedSegmentSum] = UnsortedSegmentSumInferShape; + g_infer_func[PrimType_Unsqueeze] = UnsqueezeInferShape; + g_infer_func[PrimType_Unstack] = UnstackInferShape; + g_infer_func[PrimType_Where] = WhereInferShape; + g_infer_func[PrimType_ZerosLike] = CommonInferShape; + + // fused operators. 
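+  // Inner ops are kept in a second, smaller table indexed by the offset
+  // (PrimType_Inner_X - PrimType_InnerOpMin), so g_infer_func does not
+  // have to span the inner-op ID range.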
+ g_inner_op_infer_func[PrimType_Inner_GltextureToOpencl - PrimType_InnerOpMin] = NULL; + g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL; +#ifndef RUNTIME_PASS_CLIP + g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape; + g_inner_op_infer_func[PrimType_Inner_EncoderLayer - PrimType_InnerOpMin] = EncoderLayerInferShape; + g_inner_op_infer_func[PrimType_Inner_DecoderLayer - PrimType_InnerOpMin] = DecoderLayerInferShape; + g_inner_op_infer_func[PrimType_Inner_FseDecode - PrimType_InnerOpMin] = FseDecoderInferShape; +#endif + g_inner_op_infer_func[PrimType_Inner_CustomGru - PrimType_InnerOpMin] = CustomGruInferShape; + g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL; +} + +#else +__attribute__((init_priority(101))) InferShape g_infer_func[PrimType_MAX] = {0}; +__attribute__((init_priority(101))) InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; +#endif // _MSC_VER + +InferShape GetInferFunc(int prim_type) { +#ifdef _MSC_VER + if (g_infer_func[PrimType_Abs] == NULL) { + RegAllInferFunc1(); + RegAllInferFunc2(); + RegAllInferFunc3(); + RegAllInferFunc4(); + RegAllInferFunc5(); + } +#endif + if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { + return g_infer_func[prim_type]; + } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { + return g_inner_op_infer_func[prim_type - PrimType_InnerOpMin]; + } + return NULL; +} + +void RegInfer(int prim_type, InferShape func) { + if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { + g_infer_func[prim_type] = func; + } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { + g_inner_op_infer_func[prim_type - PrimType_InnerOpMin] = func; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h new file mode 100644 index 00000000..4a43a24f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
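+ *
+ * REG_INFER(op, type, func) expands (outside MSVC) to a constructor with
+ * init priority 102, one step after the priority-101 table definitions.
+ * For example:
+ *
+ *   REG_INFER(GRU, PrimType_GRU, GruInferShape)
+ *   // expands to:
+ *   __attribute__((constructor(102))) void RegGRUInfer() { RegInfer(PrimType_GRU, GruInferShape); }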
+ */ +#ifndef MINDSPORE_NNACL_INFER_INFER_REGISTER_H_ +#define MINDSPORE_NNACL_INFER_INFER_REGISTER_H_ + +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/infer/infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void RegInfer(int prim_type, InferShape func); + +#ifdef _MSC_VER +#define REG_INFER(op, type, func) +#else +#define REG_INFER(op, type, func) \ + __attribute__((constructor(102))) void Reg##op##Infer() { RegInfer(type, func); } +#endif + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_INFER_REGISTER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.c new file mode 100644 index 00000000..41887cab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.c @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/instance_norm_infer.h" +#include "nnacl_c/infer/crop_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, inputs[0]); + if (output->format_ == Format_NC4HW4) { + output->format_ = Format_NHWC; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, inputs[0]); + if (inputs[0]->format_ != Format_NC4HW4) { + return NNACL_OK; + } + if (output->shape_size_ <= DIMENSION_2D) { + return NNACL_OK; + } + int channel = output->shape_[1]; + ShapeErase(output->shape_, &output->shape_size_, 1); + ShapePush(output->shape_, &output->shape_size_, channel); + return NNACL_OK; +} +REG_INFER(InstanceNorm, PrimType_InstanceNorm, InstanceNormInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.h new file mode 100644 index 00000000..cc90bad4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/instance_norm_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
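+ *
+ * InstanceNormInferShape mirrors the input shape onto the output; an input
+ * in Format_NC4HW4 is reported as NHWC, with the channel dimension moved
+ * from axis 1 to the last axis (see instance_norm_infer.c).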
+ */ +#ifndef MINDSPORE_NNACL_INSTANCE_NORM_INFER_H +#define MINDSPORE_NNACL_INSTANCE_NORM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INSTANCE_NORM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.c new file mode 100644 index 00000000..11be029d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.c @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/invalid_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int InvalidInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + return NNACL_INFER_INVALID; +} + +REG_INFER(PartialFusion, PrimType_PartialFusion, InvalidInferShape) +REG_INFER(Switch, PrimType_Switch, InvalidInferShape) +REG_INFER(Call, PrimType_Call, InvalidInferShape) +REG_INFER(SwitchLayer, PrimType_SwitchLayer, InvalidInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.h new file mode 100644 index 00000000..e9abbbbb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invalid_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
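+ *
+ * InvalidInferShape intentionally returns NNACL_INFER_INVALID after the
+ * null checks: control-flow ops (PartialFusion, Switch, SwitchLayer, Call)
+ * have no statically known output shape, so inference is deferred to runtime.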
+ */ +#ifndef MINDSPORE_NNACL_INVALID_INFER_H +#define MINDSPORE_NNACL_INVALID_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InvalidInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INVALID_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.c new file mode 100644 index 00000000..db56526b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/invert_permutation_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int InvertPermutationInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + if (input->shape_size_ != 1) { + return NNACL_ERR; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(InvertPermutation, PrimType_InvertPermutation, InvertPermutationInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.h new file mode 100644 index 00000000..8f5a8074 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/invert_permutation_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
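+ *
+ * InvertPermutationInferShape accepts only a 1-D kNumberTypeInt32 tensor and
+ * gives the output the same shape, e.g. an input of shape [5] yields an
+ * output of shape [5]; the actual value inversion happens in the kernel.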
+ */ +#ifndef MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H +#define MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InvertPermutationInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.c new file mode 100644 index 00000000..9c11207a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/isfinite_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int IsFiniteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = kNumberTypeBool; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + for (size_t i = 0; i < input->shape_size_; i++) { + output->shape_[i] = input->shape_[i]; + } + output->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(IsFinite, PrimType_IsFinite, IsFiniteInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.h new file mode 100644 index 00000000..46b6802d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/isfinite_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
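+ *
+ * IsFiniteInferShape copies the input shape to the output but forces the
+ * output data type to kNumberTypeBool, whatever the input type is.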
+ */ +#ifndef MINDSPORE_NNACL_ISFINITE_INFER_H_ +#define MINDSPORE_NNACL_ISFINITE_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int IsFiniteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ISFINITE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.c new file mode 100644 index 00000000..c3ea45ce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.c @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/infer/layer_norm_grad_infer.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/layernormgrad_parameter.h" +#include "nnacl_c/infer/infer_register.h" + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + LayerNormGradParameter *param = (LayerNormGradParameter *)parameter; + const TensorC *input_x = inputs[0]; + TensorC *output_dx = outputs[0]; + TensorC *output_dg = outputs[1]; + TensorC *output_db = outputs[2]; + SetDataTypeFormat(output_dx, input_x); + SetDataTypeFormat(output_dg, input_x); + SetDataTypeFormat(output_db, input_x); + SetShapeTensor(output_dx, input_x); + int begin_params_axis = param->begin_params_axis_; + if (param->begin_params_axis_ < 0) { + begin_params_axis += (int)(input_x->shape_size_); + } + size_t size = 0; + if (input_x->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + for (int i = begin_params_axis; i < input_x->shape_size_; i++) { + if (size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + output_dg->shape_[size] = input_x->shape_[i]; + output_db->shape_[size] = input_x->shape_[i]; + size++; + } + output_db->shape_size_ = size; + output_dg->shape_size_ = size; + return NNACL_OK; +} + +REG_INFER(LayerNormGrad, PrimType_LayerNormGrad, LayerNormGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.h new file mode 100644 index 00000000..dc884dc5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.c new file mode 100644 index 00000000..24a9021c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.c @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/layer_norm_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if ((inputs_size != 1 && inputs_size != 3) || (outputs_size != 1 && outputs_size != 3)) { + return NNACL_INPUT_TENSOR_ERROR; + } + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + LayerNormParameter *param = (LayerNormParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (input->shape_size_ > COMM_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->begin_params_axis_ < (-1 * (int)(input->shape_size_)) || + param->begin_params_axis_ >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + param->begin_norm_axis_ = + param->begin_norm_axis_ < 0 ? 
param->begin_norm_axis_ + ((int)(input->shape_size_)) : param->begin_norm_axis_; + SetShapeTensor(output, input); + // take care of other outputs + if (outputs_size == 3) { + TensorC *output_mean = outputs[1]; + TensorC *output_var = outputs[2]; + SetDataTypeFormat(output_mean, input); + SetDataTypeFormat(output_var, input); + int size = 0; + NNACL_CHECK_TRUE_RET(param->begin_norm_axis_ <= MAX_SHAPE_SIZE, NNACL_ERR); + for (; size < param->begin_norm_axis_; size++) { + output_mean->shape_[size] = input->shape_[size]; + output_var->shape_[size] = input->shape_[size]; + } + output_mean->shape_size_ = (size_t)size; + output_var->shape_size_ = (size_t)size; + } + + return NNACL_OK; +} + +REG_INFER(LayerNormFusion, PrimType_LayerNormFusion, LayerNormInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.h new file mode 100644 index 00000000..85d51d2b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/layer_norm_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LAYER_NORM_INFER_H +#define MINDSPORE_NNACL_LAYER_NORM_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/layer_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LAYER_NORM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.c new file mode 100644 index 00000000..d774f9f9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
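+ *
+ * LinSpaceInferShape takes the output length from the data of the third
+ * input (start, stop, num), so the shape can only be resolved once num is a
+ * known constant. Sketch of the resolution step (names from this file):
+ *
+ *   int *num = (int *)(inputs[2]->data_);   // NULL -> NNACL_INFER_INVALID
+ *   output->shape_size_ = 1;
+ *   output->shape_[0] = num[0];             // rank-1 output of length num[0]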
+ */ + +#include "nnacl_c/infer/lin_space_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int LinSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(inputs[2]) < 1) { + return NNACL_ERR; + } + int *num = (int *)(inputs[2]->data_); + if (num == NULL) { + return NNACL_INFER_INVALID; + } + output->shape_size_ = 1; + output->shape_[0] = num[0]; + return NNACL_OK; +} + +REG_INFER(LinSpace, PrimType_LinSpace, LinSpaceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.h new file mode 100644 index 00000000..1f5cf3fa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lin_space_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LIN_SPACE_INFER_H +#define MINDSPORE_NNACL_LIN_SPACE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LinSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LIN_SPACE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.c new file mode 100644 index 00000000..3d2ea4cd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
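+ *
+ * LogSoftmaxInferShape keeps the input shape, rejects inputs with more than
+ * five dimensions, and checks that the SoftmaxParameter axis lies within
+ * [-shape_size, shape_size).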
+ */ + +#include "nnacl_c/infer/log_softmax_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + const int input_size_limit = 1; + const int output_size_limit = 1; + if (inputs_size != input_size_limit || outputs_size != output_size_limit) { + return NNACL_ERR; + } + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > 5) { + return NNACL_ERR; + } + SetShapeTensor(output, input); + SoftmaxParameter *param = (SoftmaxParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ < (-1 * (int)(input->shape_size_)) || param->axis_ >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +REG_INFER(LogSoftmax, PrimType_LogSoftmax, LogSoftmaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.h new file mode 100644 index 00000000..d320fb58 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/log_softmax_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H +#define MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.c new file mode 100644 index 00000000..a70323d2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/infer/lstm_grad_data_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
+
+int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                           OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 9, 3);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+  LstmGradParameter *p = (LstmGradParameter *)parameter;
+  const TensorC *Y = inputs[SECOND_INPUT];
+  const TensorC *H = inputs[THIRD_INPUT];
+  const TensorC *C = inputs[FOURTH_INPUT];
+  const TensorC *weight = inputs[FIFTH_INPUT];
+
+  int out_shape[MAX_SHAPE_SIZE];
+  size_t out_shape_size = 0;
+
+  for (size_t i = 0; i < outputs_size; i++) {
+    SetDataTypeFormat(outputs[i], Y);
+  }
+
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  if (Y->shape_size_ != C3NUM || weight->shape_size_ != C3NUM) {
+    return NNACL_ERR;
+  }
+  /* dx: [seq_len, batch, input_size]; seq_len and batch are taken from y */
+  ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]);
+  ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]);
+  ShapePush(out_shape, &out_shape_size, p->input_size_);
+
+  SetShapeArray(outputs[FIRST_INPUT], out_shape, C3NUM);
+  SetShapeArray(outputs[SECOND_INPUT], H->shape_, H->shape_size_);
+  SetShapeArray(outputs[THIRD_INPUT], C->shape_, C->shape_size_);
+
+  return NNACL_OK;
+}
+
+REG_INFER(LSTMGradData, PrimType_LSTMGradData, LstmGradDataInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.h
new file mode 100644
index 00000000..e3a4885d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_data_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H
+#define MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                           OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.c
new file mode 100644
index 00000000..124b80b7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.c
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/lstm_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/lstm_grad_fp32.h" + +int LstmGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 11, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *H = inputs[1]; + const TensorC *C = inputs[2]; + const TensorC *weight = inputs[3]; + TensorC *output = outputs[0]; + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != 3 || weight->shape_size_ != 3) { + return NNACL_ERR; + } + + SetShapeArray(output, input->shape_, input->shape_size_); + SetShapeArray(outputs[SECOND_INPUT], H->shape_, H->shape_size_); + SetShapeArray(outputs[THIRD_INPUT], C->shape_, C->shape_size_); + SetShapeArray(outputs[FOURTH_INPUT], weight->shape_, weight->shape_size_); + + return NNACL_OK; +} + +REG_INFER(LSTMGrad, PrimType_LSTMGrad, LstmGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.h new file mode 100644 index 00000000..5044ceca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_LSTM_GRAD_INFER_H
+#define MINDSPORE_NNACL_LSTM_GRAD_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int LstmGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                       OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_LSTM_GRAD_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.c
new file mode 100644
index 00000000..0e8dfc70
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.c
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/lstm_grad_weight_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/fp32_grad/lstm_grad_fp32.h"
+
+int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                             OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[FIRST_INPUT];
+  const TensorC *H = inputs[SECOND_INPUT];
+  const TensorC *Y = inputs[THIRD_INPUT];
+
+  TensorC *output = outputs[FIRST_INPUT];
+  for (size_t i = 0; i < outputs_size; i++) {
+    SetDataTypeFormat(outputs[i], input);
+  }
+
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  if (input->shape_size_ != C3NUM || H->shape_size_ != C3NUM || Y->shape_size_ != C3NUM) {
+    return NNACL_ERR;
+  }
+  LstmGradParameter *param = (LstmGradParameter *)parameter;
+  int has_bias = param->has_bias_;
+  /* dw is flattened into [total_weight_size, 1, 1]: per direction, 4*hidden gates
+   * over input and hidden, plus two bias blocks when present */
+  int output_shape[3] = {0, 1, 1};
+  int gate_size = 4 * param->hidden_size_;
+  output_shape[0] += gate_size * param->input_size_;
+  output_shape[0] += gate_size * param->hidden_size_;
+  if (has_bias) {
+    output_shape[0] += C2NUM * gate_size;
+  }
+  int dir_mul = (param->bidirectional_) ? C2NUM : C1NUM;
+  output_shape[0] *= dir_mul;
+  SetShapeArray(output, output_shape, C3NUM);
+
+  return NNACL_OK;
+}
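+
+/* Worked example for the dw size computed above (illustrative values, not
+ * taken from the patch): input_size = 8, hidden_size = 16, has_bias = 1,
+ * bidirectional = true.
+ *   gate_size     = 4 * 16 = 64
+ *   dw_i          = 64 * 8  = 512
+ *   dw_h          = 64 * 16 = 1024
+ *   dw_bias       = 2 * 64  = 128
+ *   per direction = 512 + 1024 + 128 = 1664, x2 directions = 3328
+ * so the inferred output shape is [3328, 1, 1]. */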
+
+REG_INFER(LSTMGradWeight, PrimType_LSTMGradWeight, LstmGradWeightInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.h
new file mode 100644
index 00000000..d0ffa18f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_grad_weight_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H
+#define MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                             OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c
new file mode 100644
index 00000000..139eedfa
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.c
@@ -0,0 +1,161 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/lstm_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+/* number of intermediate values recorded per LSTM cell step */
+static const int no_of_recorded_values = 5;
+
+int CheckInputShapeValid(const TensorC *const *inputs, size_t inputs_size, const LstmParameter *parameter) {
+  if (inputs_size < C6NUM) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  const TensorC *input = inputs[FIRST_INPUT];
+  const TensorC *weight_i = inputs[SECOND_INPUT];
+  const TensorC *weight_g = inputs[THIRD_INPUT];
+  const TensorC *bias = inputs[FOURTH_INPUT];
+  const TensorC *hidden_init = inputs[FIFTH_INPUT];
+  const TensorC *cell_init = inputs[SIXTH_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+  NNACL_CHECK_NULL_RETURN_ERR(weight_i);
+  NNACL_CHECK_NULL_RETURN_ERR(weight_g);
+  NNACL_CHECK_NULL_RETURN_ERR(bias);
+  NNACL_CHECK_NULL_RETURN_ERR(hidden_init);
+  NNACL_CHECK_NULL_RETURN_ERR(cell_init);
+  NNACL_CHECK_TRUE_RET(input->shape_size_ == DIMENSION_3D && weight_i->shape_size_ == DIMENSION_3D &&
+                         weight_g->shape_size_ == DIMENSION_3D && bias->shape_size_ == DIMENSION_2D,
+                       NNACL_ERR);
+  int batch = input->shape_[kNHWC_H];
+  int input_size = input->shape_[kNHWC_W];
+  int hidden_size = weight_i->shape_[kNHWC_H] / C4NUM;
+  int out_size = hidden_size;
+  if (inputs_size == C7NUM) {
+    NNACL_CHECK_TRUE_RET(inputs[SEVENTH_INPUT]->shape_size_ == DIMENSION_3D, NNACL_INPUT_TENSOR_ERROR);
+    out_size = inputs[SEVENTH_INPUT]->shape_[kNHWC_H];
+  }
+  bool bidirectional = parameter->bidirectional_;
+  int bidirection = bidirectional ? C2NUM : C1NUM;
+  NNACL_CHECK_TRUE_RET(weight_i->shape_[kNHWC_N] == bidirection && weight_i->shape_[kNHWC_H] == hidden_size * C4NUM &&
+                         weight_i->shape_[kNHWC_W] == input_size,
+                       NNACL_ERR);
+  NNACL_CHECK_TRUE_RET(weight_g->shape_[kNHWC_N] == bidirection && weight_g->shape_[kNHWC_H] == hidden_size * C4NUM &&
+                         weight_g->shape_[kNHWC_W] == out_size,
+                       NNACL_ERR);
+  NNACL_CHECK_TRUE_RET(bias->shape_[kNHWC_N] == bidirection && bias->shape_[kNHWC_H] == hidden_size * C8NUM, NNACL_ERR);
+  if (!bidirectional && hidden_init->shape_size_ == DIMENSION_2D) {
+    NNACL_CHECK_TRUE_RET(hidden_init->shape_[kNHWC_N] == batch && hidden_init->shape_[kNHWC_H] == out_size, NNACL_ERR);
+  } else {
+    NNACL_CHECK_TRUE_RET(hidden_init->shape_size_ == DIMENSION_3D && hidden_init->shape_[kNHWC_N] == bidirection &&
+                           hidden_init->shape_[kNHWC_H] == batch && hidden_init->shape_[kNHWC_W] == out_size,
+                         NNACL_ERR);
+  }
+  if (!bidirectional && cell_init->shape_size_ == DIMENSION_2D) {
+    NNACL_CHECK_TRUE_RET(cell_init->shape_[kNHWC_N] == batch && cell_init->shape_[kNHWC_H] == hidden_size, NNACL_ERR);
+  } else {
+    NNACL_CHECK_TRUE_RET(cell_init->shape_size_ == DIMENSION_3D && cell_init->shape_[kNHWC_N] == bidirection &&
+                           cell_init->shape_[kNHWC_H] == batch && cell_init->shape_[kNHWC_W] == hidden_size,
+                         NNACL_ERR);
+  }
+  return NNACL_OK;
+}
+
+int InferFirstOutputMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) {
+  for (size_t i = 0; i < inputs_size; ++i) {
+    if (inputs[i]->shape_size_ != C3NUM) {
+      return NNACL_INPUT_TENSOR_ERROR;
+    }
+  }
+  ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_);
+  int out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT];
+  /* the mindir layout keeps three dims; the direction factor is folded into the last dim */
+  output->shape_[THIRD_INPUT] = (param->bidirectional_ ? C2NUM : C1NUM) * out_size;
+  return NNACL_OK;
+}
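+
+/* Shape walk-through for the non-mindir path below (illustrative numbers,
+ * not taken from the patch): x = [16, 4, 32] (seq_len, batch, input_size),
+ * hidden_init = [2, 4, 20], bidirectional = true. The first output starts
+ * as a copy of x, its last dim becomes out_size = 20, and the direction
+ * dim is inserted at index 1, giving [16, 2, 4, 20]. */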
+
+int InferFirstOutputNonMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) {
+  if (CheckInputShapeValid(inputs, inputs_size, param) != NNACL_OK) {
+    return NNACL_ERR;
+  }
+  ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_);
+  const TensorC *hidden_init = inputs[FIFTH_INPUT];
+  int out_size = hidden_init->shape_[hidden_init->shape_size_ - 1];
+  output->shape_[THIRD_INPUT] = out_size;
+  int direction = param->bidirectional_ ? C2NUM : C1NUM;
+  int ret = ShapeInsert(output->shape_, &output->shape_size_, 1, direction);
+  return ret;
+}
+
+int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 3);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  for (size_t i = 0; i < outputs_size; i++) {
+    SetDataTypeFormat(outputs[i], input);
+  }
+
+  LstmParameter *param = (LstmParameter *)parameter;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  int hidden_size = 0;
+  int out_size = 0;
+  if (inputs_size == C4NUM) {
+    int ret = InferFirstOutputMindir(inputs, inputs_size, output, param);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+    hidden_size = inputs[THIRD_INPUT]->shape_[THIRD_INPUT];
+    out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT];
+  } else {
+    int ret = InferFirstOutputNonMindir(inputs, inputs_size, output, param);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+    hidden_size = inputs[SIXTH_INPUT]->shape_[inputs[SIXTH_INPUT]->shape_size_ - 1];
+    out_size = inputs[FIFTH_INPUT]->shape_[inputs[FIFTH_INPUT]->shape_size_ - 1];
+  }
+
+  int dir_multiplier = param->bidirectional_ ? C2NUM : C1NUM;
+  int state_shape[MAX_SHAPE_SIZE];
+  size_t state_shape_size = 0;
+
+  ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_);
+  state_shape[FIRST_INPUT] = dir_multiplier;
+  state_shape[THIRD_INPUT] = out_size;
+  SetShapeArray(outputs[SECOND_INPUT], state_shape, state_shape_size);
+  state_shape[THIRD_INPUT] = hidden_size;
+  SetShapeArray(outputs[THIRD_INPUT], state_shape, state_shape_size);
+
+  if (outputs_size > DIMENSION_4D) {
+    /* more than four outputs: also infer the flattened intermediate-states buffer and an extra state output */
+    int intermediate_states_shape[MAX_SHAPE_SIZE];
+    const size_t intermediate_states_shape_size = 1;
+    int batch_size = input->shape_[SECOND_INPUT];
+    int seq_len = input->shape_[FIRST_INPUT];
+    intermediate_states_shape[FIRST_INPUT] =
+      batch_size * seq_len * dir_multiplier * (out_size + no_of_recorded_values * hidden_size);
+    SetShapeArray(outputs[FOURTH_INPUT], intermediate_states_shape, intermediate_states_shape_size);
+    SetShapeArray(outputs[FIFTH_INPUT], state_shape, state_shape_size);
+  }
+
+  return NNACL_OK;
+}
+
+REG_INFER(LSTM, PrimType_LSTM, LstmInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.h
new file mode 100644
index 00000000..20392e1f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/lstm_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_LSTM_INFER_H
+#define MINDSPORE_NNACL_LSTM_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                   OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_LSTM_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.c
new file mode 100644
index 00000000..a8fd494f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.c
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/matmul_infer.h"
+#include <stdlib.h> /* abs() */
+#include "nnacl_c/infer/infer_register.h"
+
+#define MIN_SHAPE_SIZE 2
+
+int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, const int *bias_shape,
+                          size_t bias_shape_size, const MatMulParameter *param) {
+  if (a_shape_size < MIN_SHAPE_SIZE || b_shape_size < MIN_SHAPE_SIZE) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
+    int min_value = MSMIN(a_shape[i], b_shape[i]);
+    int max_value = MSMAX(a_shape[i], b_shape[i]);
+    if (min_value != 0 && max_value % min_value != 0) {
+      return NNACL_INPUT_TENSOR_ERROR;
+    }
+  }
+  if (param->a_transpose_) {
+    iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - DIMENSION_2D]);
+  }
+  if (param->b_transpose_) {
+    iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]);
+  }
+  if (bias_shape_size == DIMENSION_1D && bias_shape[0] != b_shape[b_shape_size - 1]) {
+    return NNACL_ERR;
+  }
+  if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) {
+    return NNACL_ERR;
+  }
+  return NNACL_OK;
+}
+
+int CheckMatMulBias(int *shape, size_t dim_size) {
+  if (dim_size > 1) {
+    for (size_t i = 0; i < dim_size - 1; i++) {
+      if (shape[i] != DIMENSION_1D) {
+        return NNACL_ERR;
+      }
+    }
+  }
+  return NNACL_OK;
+}
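+
+/* Batch-dim alignment example for SetShape below (illustrative): with
+ * A = [6, 3, 5] and B = [5, 4], B is left-padded with 1s to [1, 5, 4];
+ * the K dims agree (5 == 5), and the batch dims broadcast to
+ * MSMAX(6, 1) = 6, so C = [6, 3, 4]. */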
+
+int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+             OpParameter *parameter) {
+  TensorC *input0 = (TensorC *)inputs[0];
+  TensorC *input1 = (TensorC *)inputs[1];
+  TensorC *output = outputs[0];
+  MatMulParameter *param = (MatMulParameter *)parameter;
+  int a_shape[MAX_SHAPE_SIZE] = {0};
+  size_t a_shape_size = 0;
+  ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_);
+  int b_shape[MAX_SHAPE_SIZE] = {0};
+  size_t b_shape_size = 0;
+  ShapeSet(b_shape, &b_shape_size, input1->shape_, input1->shape_size_);
+  int *shape_align = a_shape_size > b_shape_size ? b_shape : a_shape;
+  size_t *shape_size_align = a_shape_size > b_shape_size ? &b_shape_size : &a_shape_size;
+  int diff = abs((int)a_shape_size - (int)b_shape_size);
+  for (int i = 0; i < diff; ++i) {
+    ShapeInsert(shape_align, shape_size_align, 0, 1);
+  }
+  int bias_shape[MAX_AXIS_SIZE] = {0};
+  size_t bias_shape_size = 0;
+  if (inputs_size == kInputSize2) {
+    TensorC *bias = (TensorC *)inputs[2];
+    ShapeSet(bias_shape, &bias_shape_size, bias->shape_, bias->shape_size_);
+    NNACL_CHECK_TRUE_RET(CheckMatMulBias(bias_shape, bias_shape_size) == NNACL_OK, NNACL_ERR);
+  }
+
+  bool del_start = false;
+  bool del_end = false;
+  if (a_shape_size == 1) {
+    int insert_ret = ShapeInsert(a_shape, &a_shape_size, 0, 1);
+    if (insert_ret != NNACL_OK) {
+      return NNACL_ERR;
+    }
+    del_start = true;
+  }
+  if (b_shape_size == 1) {
+    ShapePush(b_shape, &b_shape_size, 1);
+    del_end = true;
+  }
+  int ret = CheckMatmulInputShape(a_shape, a_shape_size, b_shape, b_shape_size, bias_shape, bias_shape_size, param);
+  if (ret != NNACL_OK) {
+    return NNACL_ERR;
+  }
+  int c_shape[MAX_SHAPE_SIZE];
+  size_t c_shape_size = 0;
+  ShapeSet(c_shape, &c_shape_size, a_shape, a_shape_size);
+  c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1];
+  if (del_start) {
+    int erase_ret = ShapeErase(c_shape, &c_shape_size, 0);
+    if (erase_ret != NNACL_OK) {
+      return NNACL_ERR;
+    }
+  }
+  if (del_end) {
+    c_shape_size--;
+  }
+
+  for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) {
+    c_shape[i] = MSMAX(a_shape[i], b_shape[i]);
+  }
+
+  SetShapeArray(output, c_shape, c_shape_size);
+  return NNACL_OK;
+}
+
+int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                     OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  TensorC *input0 = (TensorC *)inputs[0];
+  TensorC *input1 = (TensorC *)inputs[1];
+  TensorC *output = outputs[0];
+
+  /* propagate dtype/format from the input produced by another node (its data_ is NULL) */
+  TensorC *input = input1->data_ == NULL ? input1 : input0;
+  SetDataTypeFormat(output, input);
+  if (input->data_type_ == kNumberTypeInt8 && parameter->quant_type_ == Quant_QuantDynamic) {
+    output->data_type_ = kNumberTypeFloat32;
+  }
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  return SetShape(inputs, inputs_size, outputs, outputs_size, parameter);
+}
+
+REG_INFER(MatMul, PrimType_MatMulFusion, MatmulInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.h
new file mode 100644
index 00000000..f4d51329
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/matmul_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_MATMUL_INFER_H +#define MINDSPORE_NNACL_MATMUL_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_MATMUL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.c new file mode 100644 index 00000000..cbeb70a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.c @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/max_min_grad_infer.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/infer/infer_register.h" + +int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x1 = inputs[0]; + const TensorC *x2 = inputs[1]; + const TensorC *dy = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (x1->shape_size_ > MAX_SHAPE_SIZE || x2->shape_size_ > MAX_SHAPE_SIZE || dy->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + param->ndim_ = dy->shape_size_; + param->in_elements_num0_ = (int)(param->ndim_); + param->in_elements_num1_ = (int)(param->ndim_); + param->out_elements_num_ = (int)(param->ndim_); + int fillDimNum0 = (int)(dy->shape_size_ - x1->shape_size_); + int fillDimNum1 = (int)(dy->shape_size_ - x2->shape_size_); + int j0 = 0; + int j1 = 0; + for (unsigned int i = 0; i < dy->shape_size_; i++) { + param->in_shape0_[i] = ((int)i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = ((int)i < fillDimNum1) ? 
1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + SetDataTypeFormat(dx1, dy); + SetDataTypeFormat(dx2, dy); + return NNACL_OK; +} + +REG_INFER(MaximumGrad, PrimType_MaximumGrad, MaxMinGradInferShape) +REG_INFER(MinimumGrad, PrimType_MinimumGrad, MaxMinGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.h new file mode 100644 index 00000000..b927f5a5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/max_min_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.c new file mode 100644 index 00000000..1d9f27a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/mfcc_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int MfccInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 3) { + return NNACL_ERR; + } + if (NNACLGetElementNum(inputs[1]) != 1) { + return NNACL_ERR; + } + output->shape_size_ = 3; + output->shape_[0] = input->shape_[0]; + output->shape_[1] = input->shape_[1]; + MfccParameter *param = (MfccParameter *)parameter; + output->shape_[2] = param->dct_coeff_num_; + return NNACL_OK; +} + +REG_INFER(Mfcc, PrimType_Mfcc, MfccInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.h new file mode 100644 index 00000000..c2b02349 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/mfcc_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_MFCC_INFER_H +#define MINDSPORE_NNACL_MFCC_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct MfccParameter { + OpParameter op_parameter_; + int dct_coeff_num_; +} MfccParameter; + +int MfccInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_MFCC_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.c new file mode 100644 index 00000000..8ab148c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/nllloss_grad_infer.h" + +#include "nnacl_c/infer/infer_register.h" + +int NLLLossGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C5NUM, C1NUM); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *logits = inputs[0]; + const TensorC *loss_grad = inputs[1]; + const TensorC *labels = inputs[2]; + const TensorC *weight = inputs[3]; + const TensorC *total_weight = inputs[4]; + if (logits->shape_size_ != C2NUM || labels->shape_size_ != C1NUM || weight->shape_size_ != C1NUM || + total_weight->shape_size_ != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (labels->shape_[0] != logits->shape_[0] || weight->shape_[0] != logits->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + + NLLLossParameter *param = (NLLLossParameter *)parameter; + if (param->reduction_type_ == Reduction_None && loss_grad->shape_size_ != C1NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->reduction_type_ != Reduction_None && loss_grad->shape_size_ != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorC *logits_grad = outputs[0]; + SetDataTypeFormat(logits_grad, logits); + SetShapeTensor(logits_grad, logits); + return NNACL_OK; +} + +REG_INFER(NLLLossGrad, PrimType_NLLLossGrad, NLLLossGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.h new file mode 100644 index 00000000..9fcb5f9d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_grad_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H +#define MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/nllloss_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NLLLossGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.c new file mode 100644 index 00000000..ac1f1411 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.c @@ -0,0 +1,52 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/nllloss_infer.h" + +#include "nnacl_c/infer/infer_register.h" + +int NLLLossInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C2NUM); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *logits = inputs[0]; + const TensorC *labels = inputs[1]; + const TensorC *weight = inputs[2]; + if (logits->shape_size_ != C2NUM || labels->shape_size_ != C1NUM || weight->shape_size_ != C1NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (labels->shape_[0] != logits->shape_[0] || weight->shape_[0] != logits->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorC *loss = outputs[0]; + TensorC *total_weight = outputs[1]; + + NLLLossParameter *param = (NLLLossParameter *)parameter; + if (param->reduction_type_ == Reduction_None) { + SetShapeTensor(loss, labels); + } else { + loss->shape_size_ = 0; + } + total_weight->shape_size_ = 0; + SetDataTypeFormat(loss, logits); + SetDataTypeFormat(total_weight, logits); + return NNACL_OK; +} + +REG_INFER(NLLLoss, PrimType_NLLLoss, NLLLossInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.h new file mode 100644 index 00000000..c9d01154 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/nllloss_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_NLLLOSS_INFER_H +#define MINDSPORE_NNACL_NLLLOSS_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/nllloss_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NLLLossInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NLLLOSS_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.c new file mode 100644 index 00000000..5f6808e9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/non_max_suppression_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int NonMaxSuppressionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeInt32; + output->format_ = input->format_; + return NNACL_INFER_INVALID; +} + +REG_INFER(NonMaxSuppression, PrimType_NonMaxSuppression, NonMaxSuppressionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.h new file mode 100644 index 00000000..b802a88a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/non_max_suppression_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H +#define MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NonMaxSuppressionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.c new file mode 100644 index 00000000..e61a8e0c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/one_hot_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int OneHotInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 4 && inputs_size != 3) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + const TensorC *depth_tensor = inputs[1]; + const TensorC *on_value = inputs[2]; + TensorC *output = outputs[0]; + const int *depth = (int *)(depth_tensor->data_); + if (depth == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(output, on_value); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + OneHotParameter *param = (OneHotParameter *)parameter; + int axis = param->axis_; + int input_rank = (int)(input->shape_size_); + if (axis < 0) { + axis += input_rank + 1; + } + if (input->shape_size_ >= MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + int res_insert = ShapeInsert(output->shape_, &output->shape_size_, axis, *depth); + if (res_insert == NNACL_ERR) { + return NNACL_ERR; + } + + return NNACL_OK; +} + +REG_INFER(OneHot, PrimType_OneHot, OneHotInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.h new file mode 100644 index 00000000..b5c0dddf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/one_hot_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ONE_HOT_INFER_H +#define MINDSPORE_NNACL_ONE_HOT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/one_hot_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int OneHotInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ONE_HOT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.c new file mode 100644 index 00000000..a0609ebb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/pad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + PadParameter *param = (PadParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ > DEFAULT_PAD_NDIMS) { + return NNACL_INPUT_TENSOR_ERROR; + } + const TensorC *paddings = inputs[1]; + int size = NNACLGetElementNum(paddings); + if (size > MAX_PAD_SIZE) { + return NNACL_PARAM_INVALID; + } + if (paddings->data_ == NULL) { + return NNACL_INFER_INVALID; + } + param->padding_length = size; + for (int i = 0; i < size; ++i) { + NNACL_CHECK_TRUE_RET(((int *)paddings->data_)[i] >= 0, NNACL_INFER_INVALID); + param->paddings_[i] = ((int *)paddings->data_)[i]; + } + + int output_shape[DEFAULT_PAD_NDIMS] = {0}; + size_t output_shape_size = 0; + for (size_t i = 0; i < input->shape_size_; i++) { + int shape = input->shape_[i] + param->paddings_[2 * i] + param->paddings_[2 * i + 1]; + ShapePush(output_shape, &output_shape_size, shape); + } + + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Pad, PrimType_PadFusion, PadInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.h new file mode 100644 index 00000000..b5d13882 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_PAD_INFER_H +#define MINDSPORE_NNACL_PAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/pad_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_PAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.c new file mode 100644 index 00000000..6929016c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_grad_infer.c @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/pooling_grad_infer.h"
+#include <math.h> /* ceil() */
+#include "nnacl_c/infer/infer_register.h"
+
+int PoolingGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                          OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  if (input->format_ != Format_NHWC) {
+    return NNACL_FORMAT_ERROR;
+  }
+  /* validate the rank before reading H and W from the shape */
+  if (input->shape_size_ != 4) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  int input_h = input->shape_[1];
+  int input_w = input->shape_[2];
+  PoolingParameter *param = (PoolingParameter *)parameter;
+  int window_h = param->window_h_;
+  int window_w = param->window_w_;
+  if (param->global_) {
+    window_h = input_h;
+    window_w = input_w;
+  }
+
+  if (param->stride_h_ == 0 || param->stride_w_ == 0) {
+    return NNACL_PARAM_INVALID;
+  }
+  if (param->pad_mode_ == Pad_same) {
+    NNACL_CHECK_ZERO_RETURN_ERR(param->stride_w_);
+    NNACL_CHECK_ZERO_RETURN_ERR(param->stride_h_);
+    int output_w = ceil((float)(input_w) / (float)(param->stride_w_));
+    int output_h = ceil((float)(input_h) / (float)(param->stride_h_));
+    int pad_h_all = ((output_h - 1) * param->stride_h_ + (window_h - 1) + 1 - input_h);
+    int pad_w_all = ((output_w - 1) * param->stride_w_ + (window_w - 1) + 1 - input_w);
+    if (pad_h_all < 0) {
+      param->pad_u_ = param->pad_d_ = 0;
+    } else {
+      param->pad_u_ = pad_h_all / 2;
+      param->pad_d_ = pad_h_all - param->pad_u_;
+    }
+    if (pad_w_all < 0) {
+      param->pad_l_ = param->pad_r_ = 0;
+    } else {
+      param->pad_l_ = pad_w_all / 2;
+      param->pad_r_ = pad_w_all - param->pad_l_;
+    }
+  }
+  SetDataTypeFormat(outputs[0], input);
+  SetShapeTensor(outputs[0], input);
+  return NNACL_OK;
+}
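+
+/* SAME-padding example for the computation above (illustrative): input_h = 7,
+ * stride_h = 2, window_h = 3 gives output_h = ceil(7 / 2) = 4 and
+ * pad_h_all = (4 - 1) * 2 + (3 - 1) + 1 - 7 = 2, split as pad_u_ = 1,
+ * pad_d_ = 1. */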
+ */
+#ifndef MINDSPORE_NNACL_POOLING_GRAD_INFER_H
+#define MINDSPORE_NNACL_POOLING_GRAD_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+#include "nnacl_c/pooling_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int PoolingGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                          OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_POOLING_GRAD_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.c
new file mode 100644
index 00000000..1d970cbf
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.c
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/pooling_infer.h"
+#include <math.h>
+#include "nnacl_c/infer/infer_register.h"
+
+int ComputePadList(PoolingParameter *param, int input_h, int input_w, int output_h, int output_w) {
+  if (param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int pad_h_all = ((output_h - 1) * param->stride_h_ + (param->window_h_ - 1) + 1 - input_h);
+  int pad_w_all = ((output_w - 1) * param->stride_w_ + (param->window_w_ - 1) + 1 - input_w);
+  if (pad_h_all < 0) {
+    param->pad_u_ = param->pad_d_ = 0;
+  } else {
+    param->pad_u_ = pad_h_all / 2;
+    param->pad_d_ = pad_h_all - param->pad_u_;
+  }
+  if (pad_w_all < 0) {
+    param->pad_l_ = param->pad_r_ = 0;
+  } else {
+    param->pad_l_ = pad_w_all / 2;
+    param->pad_r_ = pad_w_all - param->pad_l_;
+  }
+  return NNACL_OK;
+}
+
+int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                      OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  NNACL_CHECK_TRUE_RET(input->format_ == Format_NHWC, NNACL_FORMAT_ERROR);
+  for (size_t i = 0; i < outputs_size; i++) {
+    TensorC *output = outputs[i];
+    SetDataTypeFormat(output, input);
+  }
+  PoolingParameter *param = (PoolingParameter *)parameter;
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->shape_size_ < 3 || input->shape_size_ > MAX_SHAPE_SIZE) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  int input_h = input->shape_[1];
+  int input_w = input->shape_[2];
+
+  int window_h = param->window_h_;
+  int window_w = param->window_w_;
+  if (param->global_) {
+    param->window_h_ = window_h = input_h;
+    param->window_w_ = window_w = input_w;
+  }
+  int output_h = 0;
+  int output_w = 0;
+  if ((param->stride_h_ == 0 || param->stride_w_ == 0) && !param->global_) {
+    return NNACL_PARAM_INVALID;
+  }
+  if (param->pad_mode_ == Pad_same) {
+    output_w = ceil((float)(input_w) / (float)(param->stride_w_));
+    output_h = ceil((float)(input_h) / (float)(param->stride_h_));
+    if (ComputePadList(param, input_h, input_w, output_h, output_w)
!= NNACL_OK) { + return NNACL_NULL_PTR; + } + } else { + int round_mode = (RoundType)param->round_type_; + if (round_mode == RoundType_Floor) { + output_h = floor((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1; + output_w = floor((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1; + } else if (round_mode == RoundType_Ceil) { + output_h = ceil((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1; + output_w = ceil((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1; + } else { + return NNACL_ERR; + } + } + int input_shape[MAX_SHAPE_SIZE]; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + input_shape[1] = output_h > 0 ? output_h : 1; + input_shape[2] = output_w > 0 ? output_w : 1; + for (size_t i = 0; i < outputs_size; i++) { + TensorC *output = outputs[i]; + SetShapeArray(output, input_shape, input_shape_size); + } + return NNACL_OK; +} + +REG_INFER(MaxPool, PrimType_MaxPoolFusion, PoolingInferShape) +REG_INFER(AvgPool, PrimType_AvgPoolFusion, PoolingInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.h new file mode 100644 index 00000000..c5587c6e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/pooling_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_POOLING_INFER_H +#define MINDSPORE_NNACL_POOLING_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/pooling_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POOLING_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.c new file mode 100644 index 00000000..00db9d12 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/power_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int PowerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x_tensor = inputs[0]; + TensorC *exp_tensor = NULL; + if (inputs_size == 2) { + exp_tensor = (TensorC *)inputs[1]; + PowParameter *param = (PowParameter *)parameter; + float *exp_data = (float *)(exp_tensor->data_); + if (exp_data == NULL) { + return NNACL_INFER_INVALID; + } + param->power_ = *exp_data; + } + TensorC *output_tensor = outputs[0]; + + SetDataTypeFormat(output_tensor, x_tensor); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (exp_tensor != NULL) { + bool exp_x_equal = ShapeEqual(exp_tensor->shape_, exp_tensor->shape_size_, x_tensor->shape_, x_tensor->shape_size_); + if (!exp_x_equal && NNACLGetElementNum(exp_tensor) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + + SetShapeTensor(output_tensor, x_tensor); + return NNACL_OK; +} + +REG_INFER(Pow, PrimType_PowFusion, PowerInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.h new file mode 100644 index 00000000..8395060e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/power_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_POWER_INFER_H +#define MINDSPORE_NNACL_POWER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/pow_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PowerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POWER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.c new file mode 100644 index 00000000..a49b2c38 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.c @@ -0,0 +1,87 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/infer/prior_box_infer.h"
+#include <math.h>
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int PriorBoxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                       OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  output->data_type_ = kNumberTypeFloat32;
+  output->format_ = input->format_;
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  float different_aspect_ratios[MAX_SHAPE_SIZE * 2 + 1];  // NOTE: flip doubles the number
+  different_aspect_ratios[0] = 1.0;
+  int32_t different_aspect_ratios_size = 1;
+
+  PriorBoxParameter *param = (PriorBoxParameter *)parameter;
+  float *aspect_ratios = param->aspect_ratios;
+  if (aspect_ratios == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int32_t aspect_ratios_size = param->aspect_ratios_size;
+  NNACL_CHECK_TRUE_RET(aspect_ratios_size <= MAX_SHAPE_SIZE, NNACL_ERR);
+  for (int32_t i = 0; i < aspect_ratios_size; i++) {
+    float ratio = aspect_ratios[i];
+    if (fabsf(ratio) < EPSILON_VALUE) {
+      return NNACL_ERR;
+    }
+
+    bool exist = false;
+    for (int32_t j = 0; j < different_aspect_ratios_size; j++) {
+      if (fabsf(ratio - different_aspect_ratios[j]) < EPSILON_VALUE) {
+        exist = true;
+        break;
+      }
+    }
+    if (!exist) {
+      different_aspect_ratios[different_aspect_ratios_size] = ratio;
+      different_aspect_ratios_size++;
+      if (param->flip) {
+        different_aspect_ratios[different_aspect_ratios_size] = 1.0f / ratio;
+        different_aspect_ratios_size++;
+      }
+    }
+  }
+
+  int32_t min_sizes_size = param->min_sizes_size;
+  int32_t max_sizes_size = param->max_sizes_size;
+  int32_t num_priors_box = min_sizes_size * different_aspect_ratios_size + max_sizes_size;
+  const int kPriorBoxPoints = 4;
+  const int kPriorBoxN = 1;
+  const int kPriorBoxW = 1;
+  const int kPriorBoxC = 2;
+
+  int32_t h = NNACLGetHeight(input) * NNACLGetWidth(input) * num_priors_box * kPriorBoxPoints;
+  output->shape_size_ = 4;
+  output->shape_[0] = kPriorBoxN;
+  output->shape_[1] = h;
+  output->shape_[2] = kPriorBoxW;
+  output->shape_[3] = kPriorBoxC;
+  return NNACL_OK;
+}
+
+REG_INFER(PriorBox, PrimType_PriorBox, PriorBoxInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.h
new file mode 100644
index 00000000..a113415c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/prior_box_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_PRIOR_BOX_INFER_H +#define MINDSPORE_NNACL_PRIOR_BOX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/prior_box_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PriorBoxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_PRIOR_BOX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.c new file mode 100644 index 00000000..d0c58a06 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/quant_dtype_cast_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int QuantDtypeCastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + QuantDtypeCastParameter *param = (QuantDtypeCastParameter *)parameter; + output->data_type_ = param->dstT_; + NNACL_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR); + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(QuantDTypeCast, PrimType_QuantDTypeCast, QuantDtypeCastInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.h new file mode 100644 index 00000000..fba14604 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/quant_dtype_cast_infer.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H
+#define MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct QuantDtypeCastParameter {
+  OpParameter op_parameter_;
+  int srcT_;  // deprecated
+  int dstT_;
+} QuantDtypeCastParameter;
+
+int QuantDtypeCastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                             OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.c
new file mode 100644
index 00000000..3b1c7aec
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.c
@@ -0,0 +1,129 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/infer/ragged_range_infer.h"
+#include <math.h>
+#include "nnacl_c/infer/infer_register.h"
+
+int CheckInputTensor(const TensorC *const *inputs) {
+  if (inputs[0]->data_ == NULL || inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) {
+    return NNACL_INFER_INVALID;
+  }
+  if (inputs[0]->shape_size_ != 0 && inputs[0]->shape_size_ != 1) {
+    return NNACL_ERR;
+  }
+  return NNACL_OK;
+}
+
+int GetRows(const TensorC *const *inputs, bool starts_is_scalar, bool limits_is_scalar, bool deltas_is_scalar,
+            int *rows) {
+  NNACL_CHECK_NULL_RETURN_ERR(rows);
+  int sizes[3];
+  int not_scalar_count = 0;
+  if (!starts_is_scalar) {
+    sizes[not_scalar_count++] = inputs[0]->shape_[0];
+  }
+  if (!limits_is_scalar) {
+    sizes[not_scalar_count++] = inputs[1]->shape_[0];
+  }
+  if (!deltas_is_scalar) {
+    sizes[not_scalar_count++] = inputs[2]->shape_[0];
+  }
+  for (int i = 1; i < not_scalar_count; i++) {
+    if (sizes[i] != sizes[i - 1]) {
+      return NNACL_ERR;
+    }
+  }
+  *rows = not_scalar_count == 0 ? 1 : sizes[0];
+  return NNACL_OK;
+}
+
+int GetOutputValueElementNum(const TensorC *const *inputs, bool starts_is_scalar, bool limits_is_scalar,
+                             bool deltas_is_scalar, int rows, int *output_value_element_num) {
+  int count = 0;
+  switch (inputs[0]->data_type_) {
+    case kNumberTypeInt32: {
+      int *starts = (int *)(inputs[0]->data_);
+      int *limits = (int *)(inputs[1]->data_);
+      int *deltas = (int *)(inputs[2]->data_);
+      for (int i = 0; i < rows; i++) {
+        int start = starts_is_scalar ? starts[0] : starts[i];
+        int limit = limits_is_scalar ? limits[0] : limits[i];
+        int delta = deltas_is_scalar ? deltas[0] : deltas[i];
+        NNACL_CHECK_ZERO_RETURN_ERR(delta);
+        count += MSMAX((int)(ceil((float)(limit - start) / delta)), 0);
+      }
+    } break;
+    case kNumberTypeFloat32: {
+      float *starts = (float *)(inputs[0]->data_);
+      float *limits = (float *)(inputs[1]->data_);
+      float *deltas = (float *)(inputs[2]->data_);
+      for (int i = 0; i < rows; i++) {
+        float start = starts_is_scalar ? starts[0] : starts[i];
+        float limit = limits_is_scalar ?
limits[0] : limits[i]; + float delta = deltas_is_scalar ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN_ERR(delta); + count += MSMAX((ceil((limit - start) / delta)), 0); + } + } break; + default: { + return NNACL_ERR; + } + } + *output_value_element_num = count; + return NNACL_OK; +} + +int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = kNumberTypeInt32; + outputs[0]->format_ = inputs[0]->format_; + SetDataTypeFormat(outputs[1], inputs[0]); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int ret = CheckInputTensor(inputs); + if (ret != NNACL_OK) { + return ret; + } + + bool starts_is_scalar = inputs[0]->shape_size_ == 0; + bool limits_is_scalar = inputs[1]->shape_size_ == 0; + bool deltas_is_scalar = inputs[2]->shape_size_ == 0; + int rows; + ret = GetRows(inputs, starts_is_scalar, limits_is_scalar, deltas_is_scalar, &rows); + if (ret != NNACL_OK) { + return ret; + } + int output_value_element_num; + ret = GetOutputValueElementNum(inputs, starts_is_scalar, limits_is_scalar, deltas_is_scalar, rows, + &output_value_element_num); + if (ret != NNACL_OK) { + return ret; + } + outputs[0]->shape_size_ = 1; + outputs[0]->shape_[0] = rows + 1; + outputs[1]->shape_size_ = 1; + outputs[1]->shape_[0] = output_value_element_num; + return NNACL_OK; +} + +REG_INFER(RaggedRange, PrimType_RaggedRange, RaggedRangeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.h new file mode 100644 index 00000000..22613326 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/ragged_range_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RAGGED_RANGE_INFER_H +#define MINDSPORE_NNACL_RAGGED_RANGE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/ragged_range_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RAGGED_RANGE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.c new file mode 100644 index 00000000..20d18626 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/random_normal_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+int RandomNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                           OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  outputs[0]->data_type_ = inputs[0]->data_type_;
+  outputs[0]->format_ = inputs[0]->format_;
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  SetShapeTensor(outputs[0], inputs[0]);
+
+  return NNACL_OK;
+}
+
+REG_INFER(RandomNormal, PrimType_RandomNormal, RandomNormalInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.h
new file mode 100644
index 00000000..5dce4607
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_normal_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H
+#define MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int RandomNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                           OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.c
new file mode 100644
index 00000000..5f214095
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.c
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/infer/random_standard_normal_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int RandomStandardNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = kNumberTypeFloat32; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int32_t *input_data = (int32_t *)(inputs[0]->data_); + if (input_data == NULL) { + return NNACL_INFER_INVALID; + } + int input_num = NNACLGetElementNum(inputs[0]); + if (input_num > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < input_num; i++) { + ShapePush(output_shape, &output_shape_size, input_data[i]); + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + + return NNACL_OK; +} + +REG_INFER(RandomStandardNormal, PrimType_RandomStandardNormal, RandomStandardNormalInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.h new file mode 100644 index 00000000..6a31082a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/random_standard_normal_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H +#define MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RandomStandardNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.c new file mode 100644 index 00000000..619c658f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.c @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/range_infer.h"
+#include <math.h>
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/range_parameter.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                    OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, C3NUM, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  output->data_type_ = inputs_size == C3NUM ? input->data_type_ : kNumberTypeInt32;
+  output->format_ = input->format_;
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (NNACLGetElementNum(inputs[FIRST_INPUT]) < 1) {
+    return NNACL_ERR;
+  }
+  int shape_size = 0;
+  if (inputs_size == C3NUM) {
+    NNACL_CHECK_FALSE(inputs[FIRST_INPUT]->data_ == NULL, NNACL_INFER_INVALID);
+    NNACL_CHECK_FALSE(inputs[SECOND_INPUT]->data_ == NULL, NNACL_INFER_INVALID);
+    NNACL_CHECK_FALSE(inputs[THIRD_INPUT]->data_ == NULL, NNACL_INFER_INVALID);
+    if ((inputs[FIRST_INPUT]->data_type_ != inputs[SECOND_INPUT]->data_type_) ||
+        (inputs[FIRST_INPUT]->data_type_ != inputs[THIRD_INPUT]->data_type_)) {
+      return NNACL_INFER_INVALID;
+    }
+    if (NNACLGetElementNum(inputs[SECOND_INPUT]) < 1 || NNACLGetElementNum(inputs[THIRD_INPUT]) < 1) {
+      return NNACL_ERR;
+    }
+    switch (inputs[0]->data_type_) {
+      case kNumberTypeInt:
+      case kNumberTypeInt32: {
+        int start = *(int *)(inputs[0]->data_);
+        int limit = *(int *)(inputs[1]->data_);
+        int delta = *(int *)(inputs[2]->data_);
+        if (delta == 0) {
+          return NNACL_ERR;
+        }
+        shape_size = imax((int)(ceil((float)(limit - start) / delta)), 0);
+      } break;
+      case kNumberTypeFloat32:
+      case kNumberTypeFloat: {
+        float start = *(float *)(inputs[0]->data_);
+        float limit = *(float *)(inputs[1]->data_);
+        float delta = *(float *)(inputs[2]->data_);
+        if (fabsf(delta) < EPSILON_VALUE) {
+          return NNACL_ERR;
+        }
+        shape_size = imax((int)(ceil((float)(limit - start) / delta)), 0);
+      } break;
+      default: {
+        return NNACL_ERR;
+      }
+    }
+  } else {
+    RangeParameter *param = (RangeParameter *)parameter;
+    NNACL_CHECK_NULL_RETURN_ERR(param);
+    if (param->delta_ == 0) {
+      return NNACL_PARAM_INVALID;
+    }
+    shape_size = ceil((float)(param->limit_ - param->start_) / param->delta_);
+  }
+
+  output->shape_size_ = 1;
+  output->shape_[0] = shape_size;
+  return NNACL_OK;
+}
+
+REG_INFER(Range, PrimType_Range, RangeInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.h
new file mode 100644
index 00000000..eb1401a1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/range_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_RANGE_INFER_H +#define MINDSPORE_NNACL_RANGE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANGE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.c new file mode 100644 index 00000000..2c6d9299 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/rank_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int RankInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + output->shape_size_ = 1; + output->shape_[0] = 1; + return NNACL_OK; +} + +REG_INFER(Rank, PrimType_Rank, RankInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.h new file mode 100644 index 00000000..5f7d2c46 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rank_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_RANK_INFER_H +#define MINDSPORE_NNACL_RANK_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RankInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.c new file mode 100644 index 00000000..00e68b34 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.c @@ -0,0 +1,95 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/reduce_concat_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/split_parameter.h" + +int DataTypeJudge2(const TensorC *input, const TensorC *output) { + if ((input->data_type_ != output->data_type_) && + !((input->data_type_ == kNumberTypeFloat16 && output->data_type_ == kNumberTypeFloat32) || + (input->data_type_ == kNumberTypeFloat32 && output->data_type_ == kNumberTypeFloat16))) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +int ReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *input0_shape = inputs[0]->shape_; + size_t input0_shape_size = inputs[0]->shape_size_; + + int axis = C2NUM; + if (axis < 0 || axis >= (int)input0_shape_size) { + return NNACL_ERR; + } + if (input0_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; + size_t input0_shape_without_axis_size = 0; + ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); + int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_axis_dim = input0_shape[axis]; + for (size_t i = 1; i < inputs_size; ++i) { + size_t input_i_shape_size = inputs[i]->shape_size_; + if (input_i_shape_size != input0_shape_size) { + return NNACL_PARAM_INVALID; + } + int shape_tmp[MAX_SHAPE_SIZE] = {0}; + size_t shape_tmp_size = 0; + ShapeSet(shape_tmp, &shape_tmp_size, inputs[i]->shape_, inputs[i]->shape_size_); + int data_type_judge = DataTypeJudge2(inputs[i], output); + if (data_type_judge != NNACL_OK) { + return data_type_judge; + } + int axis_tmp = shape_tmp[axis]; + 
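      // Illustrative note (hedged sketch, values are hypothetical): the concat axis is
+      // fixed at C2NUM (= 2), so for inputs of shape [2, 4, 8] and [2, 4, 16] this loop
+      // accumulates output_axis_dim = 8 + 16 = 24 and the inferred output is [2, 4, 24].
+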
erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis);
+    if (erase_ret != NNACL_OK) {
+      return NNACL_ERR;
+    }
+
+    output_axis_dim += axis_tmp;
+  }
+  int output_shape[MAX_SHAPE_SIZE];
+  size_t output_shape_size = input0_shape_size;
+  for (size_t i = 0; i < input0_shape_size; i++) {
+    output_shape[i] = input0_shape[i];
+  }
+  output_shape[axis] = output_axis_dim;
+  SetShapeArray(outputs[0], output_shape, output_shape_size);
+  return NNACL_OK;
+}
+
+REG_INFER(ReduceConcatFusion, PrimType_Inner_ReduceConcatFusion, ReduceConcatFusionInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.h
new file mode 100644
index 00000000..267855ec
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_concat_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H
+#define MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int ReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
+                                 size_t outputs_size, OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.c
new file mode 100644
index 00000000..010d3739
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.c
@@ -0,0 +1,140 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/infer/reduce_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ReduceOnAllAxes(const TensorC *input, TensorC *output, int *out_shape, size_t out_shape_size, bool keep_dims) { + if (keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + ShapePush(out_shape, &out_shape_size, 1); + } + } + SetShapeArray(output, out_shape, out_shape_size); + output->data_type_ = input->data_type_; + return NNACL_OK; +} + +int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, const int *actual_axes, TensorC *output, int *out_shape, + size_t out_shape_size, bool keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + bool reduce_axis = false; + for (size_t idx = 0; idx < num_axes; ++idx) { + if ((size_t)(actual_axes[idx]) == i || (size_t)(actual_axes[idx]) + input->shape_size_ == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + ShapePush(out_shape, &out_shape_size, 1); + } + } else { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +bool IsReduceAllAxes(const TensorC *const *inputs, size_t inputs_size) { + if (inputs_size == 1) { + return true; + } + // When axes not given, reduce op will have two input tensor by the old version converter_lite tool. + if (inputs_size == 2 && inputs[1]->shape_size_ == 1 && inputs[1]->shape_[0] == 0) { + return true; + } + return false; +} + +int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ReduceParameter *param = (ReduceParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + bool keep_dims = param->keep_dims_; + int out_shape[MAX_SHAPE_SIZE] = {0}; + const size_t out_shape_size = 0; + if (IsReduceAllAxes(inputs, inputs_size)) { + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); + } + + // get axes from input tensor + const TensorC *axes_input = inputs[1]; + NNACL_CHECK_NULL_RETURN_ERR(axes_input->data_); + + int num_axes; + if (axes_input->shape_size_ == 1) { + num_axes = axes_input->shape_[0]; + } else if (axes_input->shape_size_ == 0) { + num_axes = 1; + } else { + return NNACL_ERR; + } + if (num_axes > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int rank = (int)(input->shape_size_); + if (rank > MAX_SHAPE_SIZE || rank < 0) { + return NNACL_ERR; + } + int actual_axes[MAX_SHAPE_SIZE] = {0}; + size_t actual_axes_size = 0; + int ret = GetInt32DataFromTensor(axes_input, actual_axes, &actual_axes_size); + if (ret != NNACL_OK) { + return ret; + } + + if (param->reduce_to_end_) { + if (num_axes != 1) { + return NNACL_ERR; + } + + if (actual_axes[0] < -1 * rank || actual_axes[0] >= rank) { + return NNACL_PARAM_INVALID; + } + int begin_axis; + begin_axis = actual_axes[0] < 0 ? 
actual_axes[0] + rank : actual_axes[0]; + for (int i = begin_axis + 1; i < rank; ++i) { + ShapePush(actual_axes, &actual_axes_size, i); + } + num_axes = rank - begin_axis; + keep_dims = false; + } + // reduce on all axes + if (num_axes == 0) { + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); + } + // reduce on selected axes + return ReduceOnSelectedAxes(input, (size_t)num_axes, actual_axes, output, out_shape, out_shape_size, keep_dims); +} + +REG_INFER(Reduce, PrimType_ReduceFusion, ReduceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.h new file mode 100644 index 00000000..c60eea7c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_INFER_H +#define MINDSPORE_NNACL_REDUCE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/reduce_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_REDUCE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.c new file mode 100644 index 00000000..6c9c0a3e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/reduce_scatter_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ReduceScatterParameter *param = (ReduceScatterParameter *)parameter; + if (param->rank_size_ <= 0) { + return NNACL_INFER_INVALID; + } + + const TensorC *input_tensor = inputs[0]; + const int *in_shape = input_tensor->shape_; + TensorC *out_tensor = outputs[0]; + + if (in_shape[0] % param->rank_size_ != 0) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + out_shape[0] = in_shape[0] / param->rank_size_; + out_shape_size++; + for (int i = 1; i < input_tensor->shape_size_; i++) { + out_shape[i] = in_shape[i]; + out_shape_size++; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + + return NNACL_OK; +} + +REG_INFER(ReduceScatter, PrimType_ReduceScatter, ReduceScatterInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.h new file mode 100644 index 00000000..80246f1c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reduce_scatter_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H +#define MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/reduce_scatter_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.c new file mode 100644 index 00000000..8c28b9dd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.c @@ -0,0 +1,221 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/reshape_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size_t *out_shape_size, int shape_size) { + int input_count = NNACLGetElementNum(inputs[0]); + int index = 0; + int size = 1; + for (int i = 0; i < shape_size; i++) { + if ((int)(data[i]) == -1) { + index = i; + } else if ((int)(data[i]) == 0) { + size *= inputs[0]->shape_[i]; + } else { + size *= data[i]; + } + ShapePush(out_shape, out_shape_size, data[i]); + } + if (size == 0) { + return NNACL_ERR; + } + if ((int)(data[index]) == -1) { + if (index >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + out_shape[index] = input_count / size; + } + return NNACL_OK; +} + +int CalNewShape(const TensorC *in_tensor, int *out_shape, size_t out_shape_size) { + int in_shape_size = 1; + for (size_t i = 0; i < in_tensor->shape_size_; i++) { + in_shape_size *= in_tensor->shape_[i]; + } + int64_t infer_index = -1; + int out_shape_size_new = 1; + for (size_t i = 0; i < out_shape_size; i++) { + if (out_shape[i] == -1) { + if (infer_index == -1) { + infer_index = (int64_t)(i); + } else { + return NNACL_ERR; + } + } else if (out_shape[i] < 0) { + return NNACL_ERR; + } else if (out_shape[i] == 0) { + if (NNACLGetElementNum(in_tensor) != 0) { + out_shape[i] = in_tensor->shape_[i]; + out_shape_size_new *= out_shape[i]; + } else { + out_shape_size_new = 0; + break; + } + } else { + out_shape_size_new *= out_shape[i]; + } + } + if (infer_index == -1 && out_shape_size_new != in_shape_size) { + return NNACL_ERR; + } + if (infer_index != -1) { + if (out_shape_size_new == 0) { + return NNACL_ERR; + } + if (infer_index >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + out_shape[infer_index] = in_shape_size / out_shape_size_new; + } + return NNACL_OK; +} + +int CalShapeByType(const TensorC *const *inputs, size_t shape_size, int *out_shape, size_t *out_shape_size) { + const TensorC *shape_tensor = inputs[1]; + if (shape_size == 0) { + return NNACL_ERR; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW((sizeof(int)), shape_size), NNACL_ERR); + int *data_int = (int *)malloc(sizeof(int) * shape_size); + if (data_int == NULL) { + return NNACL_ERR; + } + switch (shape_tensor->data_type_) { + case kNumberTypeInt8: { + int8_t *data = (int8_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeInt32: { + int32_t *data = (int32_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeInt64: { + int64_t *data = (int64_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeFloat: { + float *data = (float *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return 
NNACL_ERR; + } + } break; + case kNumberTypeUInt32: { + uint32_t *data = (uint32_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = (int)data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + default: { + free(data_int); + return NNACL_ERR; + } + } + free(data_int); + return NNACL_OK; +} + +int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + ReshapeParameter *param = (ReshapeParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + if (inputs_size == 2) { + const TensorC *shape_tensor = inputs[1]; + if (NNACLGetElementNum(input) == 1) { + if (shape_tensor->data_ == NULL || (shape_tensor->shape_size_ == 1 && shape_tensor->shape_[0] == 0)) { + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; + } + } + + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int shape_size = NNACLGetElementNum(shape_tensor); + if (shape_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int calRet = CalShapeByType(inputs, shape_size, out_shape, &out_shape_size); + if (calRet != NNACL_OK) { + return calRet; + } + } else if (inputs_size == 1) { + if (param->shape_dim_ > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + for (int i = 0; i < param->shape_dim_; ++i) { + ShapePush(out_shape, &out_shape_size, param->shape_[i]); + } + } else { + return NNACL_ERR; + } + int ret = CalNewShape(inputs[0], out_shape, out_shape_size); + if (ret != NNACL_OK) { + return ret; + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Reshape, PrimType_Reshape, ReshapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.h new file mode 100644 index 00000000..1f79b52f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/reshape_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
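The -1/0 handling in CalShape/CalNewShape above is the core of Reshape inference: 0 copies the matching input dimension and a single -1 absorbs whatever element count remains. A compilable sketch under those semantics, with a hypothetical resolve_reshape helper and no NNACL types.

#include <stdio.h>

static int resolve_reshape(const int *in_shape, int in_dims, int *req, int req_dims) {
  int element_count = 1;
  for (int i = 0; i < in_dims; ++i) element_count *= in_shape[i];
  int known = 1, infer_index = -1;
  for (int i = 0; i < req_dims; ++i) {
    if (req[i] == -1) {
      infer_index = i;        // at most one -1 is meaningful
    } else if (req[i] == 0) {
      req[i] = in_shape[i];   // 0 means "keep this input dim"
      known *= req[i];
    } else {
      known *= req[i];
    }
  }
  if (known == 0) return -1;
  if (infer_index >= 0) req[infer_index] = element_count / known;
  return 0;
}

int main(void) {
  int in_shape[] = {2, 3, 4};
  int req[] = {0, -1};        // keep dim 0, infer the rest
  if (resolve_reshape(in_shape, 3, req, 2) == 0) {
    printf("[%d, %d]\n", req[0], req[1]);  // [2, 12]
  }
  return 0;
}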
+ */ +#ifndef MINDSPORE_NNACL_RESHAPE_INFER_H +#define MINDSPORE_NNACL_RESHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/reshape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.c new file mode 100644 index 00000000..0867769e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/resize_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input_1 = inputs[1]; + if (input_1->shape_size_ == 4) { + ShapeSet(output->shape_, &output->shape_size_, input_1->shape_, input_1->shape_size_); + } else if (input_1->shape_size_ == 1 && input_1->shape_[0] == 2 && input_1->data_type_ == kNumberTypeInt32) { + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + int32_t *data = (int32_t *)(input_1->data_); + + ShapePush(output_shape, &output_shape_size, NNACLGetBatch(input)); + ShapePush(output_shape, &output_shape_size, data[0]); + ShapePush(output_shape, &output_shape_size, data[1]); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +REG_INFER(ResizeGrad, PrimType_ResizeGrad, ResizeGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.h new file mode 100644 index 00000000..87f2786a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ +#define MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32_grad/resize_grad.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.c new file mode 100644 index 00000000..daf5e9de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/resize_infer.h" +#include +#include +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" + +int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) { + const TensorC *input = inputs[0]; + const TensorC *shape_tensor = inputs[1]; + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int shape_size = NNACLGetElementNum(shape_tensor); + void *origin_data = shape_tensor->data_; + if (origin_data == NULL) { + return NNACL_INFER_INVALID; + } + switch (shape_size) { + case 2: + case 4: { + int height_index = 0; + int width_index = 1; + if (shape_size == 4) { + height_index = kNHWC_H; + width_index = kNHWC_W; + } + if (shape_tensor->data_type_ == kNumberTypeInt32) { + int32_t *data = (int32_t *)(origin_data); + param->new_height_ = data[height_index]; + param->new_width_ = data[width_index]; + } else if (shape_tensor->data_type_ == kNumberTypeFloat32) { + float *data = (float *)(origin_data); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[height_index]), NNACLGetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[width_index]), NNACLGetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW); + param->new_height_ = round(data[height_index] * NNACLGetHeight(input)); + param->new_width_ = round(data[width_index] * NNACLGetWidth(input)); + } else if (shape_tensor->data_type_ == kNumberTypeFloat16) { + uint16_t *data = (uint16_t *)(shape_tensor->data_); + float scale_height = ShortToFloat32(data[height_index]); + float scale_width = ShortToFloat32(data[width_index]); + param->new_height_ = round(scale_height * NNACLGetHeight(input)); + param->new_width_ = round(scale_width * NNACLGetWidth(input)); + } + break; + } + case 1: { + // caffe zoom_factor + int scale; + if (shape_tensor->data_type_ == kNumberTypeInt32) { + int *data = (int *)(origin_data); + scale = data[0]; + } else { + return NNACL_ERR; + } + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetHeight(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetWidth(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW); + param->new_height_ = NNACLGetHeight(input) + (NNACLGetHeight(input) - 1) * (scale - 1); + param->new_width_ = NNACLGetWidth(input) + (NNACLGetWidth(input) - 1) * (scale - 1); + break; + } + default: { + return NNACL_ERR; + } + } + return NNACL_OK; +} + +int CalculateNewHeightAndWidth(const TensorC *const *inputs, size_t inputs_size, ResizeParameter *param) { + if (inputs_size == 2) { + return HandleTwoInputs(inputs, param); + } else if (inputs_size == 1) { + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 0 && input->shape_size_ != 4) { + return NNACL_ERR; + } + ResizeParameter *param = (ResizeParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapePush(output_shape, &output_shape_size, 
NNACLGetBatch(input)); + int ret = CalculateNewHeightAndWidth(inputs, inputs_size, param); + if (ret == NNACL_OK) { + ShapePush(output_shape, &output_shape_size, param->new_height_); + ShapePush(output_shape, &output_shape_size, param->new_width_); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + } + return ret; +} + +REG_INFER(Resize, PrimType_Resize, ResizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.h new file mode 100644 index 00000000..c9549b05 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/resize_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESIZE_INFER_H +#define MINDSPORE_NNACL_RESIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/resize_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.c new file mode 100644 index 00000000..a95fa84f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
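HandleTwoInputs above accepts three encodings of Resize's second input: an explicit int H/W pair, a float scale pair, and a single-element caffe zoom_factor. A sketch of the latter two formulas, assuming round-to-nearest as in the hunk; the helper names are made up.

#include <math.h>
#include <stdio.h>

static void resize_from_scale(int h, int w, float sh, float sw, int *nh, int *nw) {
  *nh = (int)round(sh * h);
  *nw = (int)round(sw * w);
}

static void resize_from_zoom(int h, int w, int zoom, int *nh, int *nw) {
  // caffe semantics: interior intervals are multiplied, border points stay fixed.
  *nh = h + (h - 1) * (zoom - 1);
  *nw = w + (w - 1) * (zoom - 1);
}

int main(void) {
  int nh, nw;
  resize_from_scale(10, 20, 2.0f, 1.5f, &nh, &nw);
  printf("scale: %dx%d\n", nh, nw);  // 20x30
  resize_from_zoom(10, 20, 3, &nh, &nw);
  printf("zoom:  %dx%d\n", nh, nw);  // 28x58
  return 0;
}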
+ */ + +#include "nnacl_c/infer/rfft_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int RfftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeComplex64; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ >= MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + RfftParameter *param = (RfftParameter *)parameter; + if (input->shape_size_ < 1) { + return NNACL_ERR; + } + output->shape_[input->shape_size_ - 1] = param->fft_length_ / 2 + 1; + ShapePush(output->shape_, &(output->shape_size_), 2); + return NNACL_OK; +} + +REG_INFER(Rfft, PrimType_Rfft, RfftInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.h new file mode 100644 index 00000000..c863ede9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/rfft_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RFFT_INFER_H +#define MINDSPORE_NNACL_RFFT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct RfftParameter { + OpParameter op_parameter_; + int fft_length_; +} RfftParameter; + +int RfftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RFFT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.c new file mode 100644 index 00000000..6ae35800 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/roi_pooling_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ROIPoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + const TensorC *roi = inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ROIPoolingParameter *param = (ROIPoolingParameter *)parameter; + output->shape_size_ = 4; + output->shape_[0] = roi->shape_[0]; + output->shape_[1] = param->pooledH_; + output->shape_[2] = param->pooledW_; + output->shape_[3] = NNACLGetChannel(input); + return NNACL_OK; +} + +REG_INFER(ROIPooling, PrimType_ROIPooling, ROIPoolingInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.h new file mode 100644 index 00000000..4410ced9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/roi_pooling_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ROI_POOLING_INFER_H +#define MINDSPORE_NNACL_ROI_POOLING_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/roi_pooling_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ROIPoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ROI_POOLING_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.c new file mode 100644 index 00000000..536244b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/scatter_nd_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *shape = inputs[THIRD_INPUT]; + if (shape->data_ == NULL) { + return NNACL_INFER_INVALID; + } + const TensorC *update = inputs[SECOND_INPUT]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, update); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int *shape_data = (int *)(shape->data_); + NNACL_CHECK_TRUE_RET(NNACLGetElementNum(shape) <= MAX_SHAPE_SIZE, NNACL_ERR); + SetShapeArray(output, shape_data, (size_t)NNACLGetElementNum(shape)); + return NNACL_OK; +} + +REG_INFER(ScatterNd, PrimType_ScatterNd, ScatterNdInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.h new file mode 100644 index 00000000..154c356c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SCATTER_ND_INFER_H +#define MINDSPORE_NNACL_SCATTER_ND_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SCATTER_ND_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.c new file mode 100644 index 00000000..5ae93168 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/scatter_nd_update_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input_x = inputs[0]; + const TensorC *indices = inputs[1]; + const TensorC *updates = inputs[2]; + TensorC *output = outputs[0]; + if (updates->data_type_ != input_x->data_type_ || + (indices->data_type_ != kNumberTypeInt32 && indices->data_type_ != kNumberTypeInt64)) { + return NNACL_ERR; + } + SetDataTypeFormat(output, input_x); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (indices->shape_size_ < 2 || indices->shape_[indices->shape_size_ - 1] > input_x->shape_size_) { + return NNACL_ERR; + } + if (updates->shape_size_ != + (indices->shape_size_ - 1) + input_x->shape_size_ - indices->shape_[indices->shape_size_ - 1]) { + return NNACL_ERR; + } + for (int i = 0; i < updates->shape_size_; i++) { + if ((i < indices->shape_size_ - 1 && updates->shape_[i] != indices->shape_[i]) || + (i >= indices->shape_size_ - 1 && + updates->shape_[i] != + input_x->shape_[indices->shape_[indices->shape_size_ - 1] + i - indices->shape_size_ + 1])) { + return NNACL_ERR; + } + } + SetShapeArray(output, input_x->shape_, input_x->shape_size_); + return NNACL_OK; +} + +REG_INFER(ScatterNdUpdate, PrimType_ScatterNdUpdate, ScatterNdUpdateInferShape) +REG_INFER(TensorScatterAdd, PrimType_TensorScatterAdd, ScatterNdUpdateInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.h new file mode 100644 index 00000000..37b82748 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/scatter_nd_update_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H +#define MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.c new file mode 100644 index 00000000..1b589d8f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/select_infer.h" +#include +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensorlist_c_utils.h" + +int SelectInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = + CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2 * outputs_size + 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + for (size_t i = 0; i < outputs_size; i++) { + const TensorC *input = inputs[i + 1]; + TensorC *output = outputs[i]; + SetDataTypeFormat(output, input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + for (size_t i = 0; i < outputs_size; i++) { + const TensorC *input = inputs[i + 1]; + TensorC *output = outputs[i]; + if (input->data_type_ == kObjectTypeTensorType) { + TensorListC *input_tensorlist = (TensorListC *)(input); + TensorListC *output_tensorlist = (TensorListC *)(output); + output_tensorlist->element_shape_size_ = input_tensorlist->element_shape_size_; + for (size_t j = 0; j < input_tensorlist->element_shape_size_; j++) { + output_tensorlist->element_shape_[j] = input_tensorlist->element_shape_[j]; + } + output_tensorlist->max_elements_num_ = input_tensorlist->max_elements_num_; + output_tensorlist->tensors_data_type_ = input_tensorlist->tensors_data_type_; + output_tensorlist->element_num_ = input_tensorlist->element_num_; + + for (size_t j = 0; j < output_tensorlist->element_num_; j++) { + memcpy(&output_tensorlist->tensors_[j], &input_tensorlist->tensors_[j], sizeof(TensorC)); + } + } else { + SetShapeTensor(output, input); + } + } + return NNACL_OK; +} + +REG_INFER(Select, PrimType_Select, SelectInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.h new file mode 100644 index 00000000..8575b19e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/select_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
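The ScatterNdUpdate/TensorScatterAdd validation registered above enforces, in shape terms, updates.shape == indices.shape[:-1] ++ x.shape[indices.shape[-1]:]. A standalone checker makes the rule easier to exercise; the names and the plain-array encoding are illustrative only.

#include <stdio.h>

static int scatter_update_shape_ok(const int *x, int x_dims, const int *idx, int idx_dims,
                                   const int *upd, int upd_dims) {
  if (idx_dims < 2 || idx[idx_dims - 1] > x_dims) return 0;
  int expect = (idx_dims - 1) + x_dims - idx[idx_dims - 1];  // expected update rank
  if (upd_dims != expect) return 0;
  for (int i = 0; i < upd_dims; ++i) {
    // leading dims follow indices, trailing dims follow the updated tensor
    int want = (i < idx_dims - 1) ? idx[i] : x[idx[idx_dims - 1] + i - (idx_dims - 1)];
    if (upd[i] != want) return 0;
  }
  return 1;
}

int main(void) {
  int x[] = {4, 5, 6};   // tensor being updated
  int idx[] = {2, 1};    // 2 index rows, each addressing the first axis of x
  int upd[] = {2, 5, 6}; // must be idx[:-1] ++ x[1:]
  printf("%s\n", scatter_update_shape_ok(x, 3, idx, 2, upd, 3) ? "valid" : "invalid");
  return 0;
}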
+ */ +#ifndef MINDSPORE_NNACL_SELECT_INFER_H +#define MINDSPORE_NNACL_SELECT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SelectInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SELECT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.c new file mode 100644 index 00000000..b198cdf5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/sgd_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SgdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 6); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[3]) || NNACLGetElementNum(inputs[2]) != 1 || + NNACLGetElementNum(inputs[4]) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(SGD, PrimType_SGD, SgdInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.h new file mode 100644 index 00000000..8246a6a2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sgd_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SGD_INFER_H +#define MINDSPORE_NNACL_SGD_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SgdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SGD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c new file mode 100644 index 00000000..8d4dfba0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c @@ -0,0 +1,97 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/shape_fusion_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CalculateOutput(const TensorC *in_tensor, const TensorC *matrix_tensor, TensorC *out_tensor, size_t input_len, + size_t origin_out_size) { + size_t out_size = out_tensor->shape_size_ == 0 ? 1 : (size_t)(out_tensor->shape_[0]); + if (out_size != origin_out_size && out_tensor->data_ != NULL) { + free(out_tensor->data_); + out_tensor->data_ = NULL; + } + size_t matrix_data_size = input_len * out_size * sizeof(float); + float *matrix_data = (float *)(malloc(matrix_data_size)); + NNACL_CHECK_NULL_RETURN_ERR(matrix_data); + if (matrix_tensor->data_type_ == kNumberTypeFloat32 || matrix_tensor->data_type_ == kNumberTypeFloat) { + memcpy(matrix_data, matrix_tensor->data_, matrix_data_size); +#ifdef ENABLE_FP16 + } else if (matrix_tensor->data_type_ == kNumberTypeFloat16) { + for (size_t i = 0; i < input_len * out_size; i++) { + matrix_data[i] = (float)(((float16_t *)(matrix_tensor->data_))[i]); + } +#endif + } else { + free(matrix_data); + return NNACL_ERR; + } + if (out_tensor->data_ == NULL) { + out_tensor->data_ = malloc(out_size * sizeof(int)); + } + int *data = (int *)out_tensor->data_; + if (data == NULL) { + free(matrix_data); + return NNACL_ERR; + } + memset(data, 0, out_size * sizeof(int)); + for (size_t i = 0; i < out_size; i++) { + for (size_t j = 0; j < input_len - 1; j++) { + data[i] += (int)(in_tensor->shape_[j] * matrix_data[i * input_len + j]); + } + data[i] += (int)(matrix_data[i * input_len + input_len - 1]); + } + free(matrix_data); + return NNACL_OK; +} + +int ShapeFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size == outputs_size + 1, NNACL_INPUT_TENSOR_ERROR); + const TensorC *in_tensor = inputs[0]; + size_t input_len = in_tensor->shape_size_ + 1; + for (size_t out_idx = 0; out_idx < outputs_size; out_idx++) { + TensorC *out_tensor = outputs[out_idx]; + size_t origin_out_size = + out_tensor->data_ == NULL ? 
0 : (out_tensor->shape_size_ == 0 ? 1 : (size_t)out_tensor->shape_[0]); + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // calculate output tensor shape. + const TensorC *matrix_tensor = inputs[out_idx + 1]; + if (matrix_tensor->shape_size_ == 1) { + out_tensor->shape_size_ = 0; + out_tensor->shape_[0] = 0; + } else { + out_tensor->shape_size_ = 1; + out_tensor->shape_[0] = (int)(matrix_tensor->shape_[0]); + } + int ret = CalculateOutput(in_tensor, matrix_tensor, out_tensor, input_len, origin_out_size); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +REG_INFER(ShapeFusion, PrimType_Inner_ShapeFusion, ShapeFusionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.h new file mode 100644 index 00000000..3c014100 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ +#define MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ShapeFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.c new file mode 100644 index 00000000..8a2e3ff2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
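CalculateOutput above folds a chain of shape arithmetic into a single affine map: each output element is the dot product of the input shape, extended with a constant 1, against one row of the fused matrix. A float-matrix sketch of that computation; shape_fusion_apply is a made-up name.

#include <stdio.h>

static void shape_fusion_apply(const int *in_shape, int in_dims, const float *matrix,
                               int out_size, int *out) {
  int input_len = in_dims + 1;  // trailing matrix column holds the bias
  for (int i = 0; i < out_size; ++i) {
    int acc = 0;
    for (int j = 0; j < in_dims; ++j) {
      acc += (int)(in_shape[j] * matrix[i * input_len + j]);
    }
    acc += (int)matrix[i * input_len + in_dims];  // bias term
    out[i] = acc;
  }
}

int main(void) {
  int in_shape[] = {2, 3, 4};
  // One output row computing 2*dim0 + dim2 + 1 = 2*2 + 4 + 1 = 9.
  float matrix[] = {2.0f, 0.0f, 1.0f, 1.0f};
  int out[1];
  shape_fusion_apply(in_shape, 3, matrix, 1, out);
  printf("%d\n", out[0]);  // 9
  return 0;
}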
+ */ + +#include "nnacl_c/infer/shape_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int ShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + out_tensor->shape_size_ = 1; + out_tensor->shape_[0] = (int)(in_tensor->shape_size_); + return NNACL_OK; +} + +REG_INFER(Shape, PrimType_Shape, ShapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.h new file mode 100644 index 00000000..27721d0b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/shape_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SHAPE_INFER_H +#define MINDSPORE_NNACL_SHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.c new file mode 100644 index 00000000..8b9ec9bd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/size_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + out_tensor->shape_size_ = 0; + out_tensor->shape_[0] = 1; + + return NNACL_OK; +} + +REG_INFER(SizeOp, PrimType_Size, SizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.h new file mode 100644 index 00000000..f1ccf7cf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/size_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SIZE_INFER_H +#define MINDSPORE_NNACL_SIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.c new file mode 100644 index 00000000..c041b3ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.c @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/slice_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +static bool CheckInputsDataType(const TensorC *const *inputs, size_t inputs_size) { + // not support data_type of slice's begin and size is not int32 + if (inputs_size >= 2) { + if (inputs[1]->data_type_ != kNumberTypeInt32) { + return false; + } + } + if (inputs_size == 3) { + if (inputs[2]->data_type_ != kNumberTypeInt32) { + return false; + } + } + return true; +} + +int InitBeginAndSizeParam(const TensorC *const *inputs, int *begin, int *size, int param_length) { + /* init begin parameter */ + int slice_begin_size = NNACLGetElementNum(inputs[1]); + int *begin_ptr = (int *)(inputs[1]->data_); + if (slice_begin_size != param_length || begin_ptr == NULL) { + return NNACL_INFER_INVALID; + } + if (slice_begin_size > MAX_AXIS_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < slice_begin_size; i++) { + begin[i] = begin_ptr[i]; + } + + /* init size parameter */ + int slice_size_size = NNACLGetElementNum(inputs[2]); + int *size_ptr = (int *)(inputs[2]->data_); + if (slice_size_size != param_length || size_ptr == NULL) { + return NNACL_INFER_INVALID; + } + if (slice_size_size > MAX_AXIS_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < slice_size_size; i++) { + size[i] = size_ptr[i]; + } + return NNACL_OK; +} + +int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + + if (!CheckInputsDataType(inputs, inputs_size)) { + return NNACL_ERR; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + SliceParameter *param = (SliceParameter *)parameter; + int param_length = (int)(input->shape_size_); + output->shape_size_ = input->shape_size_; + int begin[MAX_SHAPE_SIZE]; + int size[MAX_SHAPE_SIZE]; + + ret = InitBeginAndSizeParam(inputs, begin, size, param_length); + if (ret != NNACL_OK) { + return ret; + } + + for (int32_t i = 0; i < param_length; ++i) { + if (param->axis_[i] < 0) { + NNACL_CHECK_INT_ADD_NOT_OVERFLOW(param->axis_[i], (int)input->shape_size_, NNACL_PARAM_INVALID); + param->axis_[i] += (int)input->shape_size_; + } + NNACL_CHECK_TRUE_RET(param->axis_[i] >= 0 && param->axis_[i] < param_length, NNACL_PARAM_INVALID); + begin[param->axis_[i]] = begin[i]; + size[param->axis_[i]] = size[i]; + } + + for (int32_t i = 0; i < param_length; ++i) { + if (size[i] < 0 && size[i] != -1) { + return NNACL_PARAM_INVALID; + } + if (begin[i] < 0) { + return NNACL_PARAM_INVALID; + } + if (input->shape_[i] < begin[i]) { + return NNACL_PARAM_INVALID; + } + if (size[i] > (input->shape_[i] - begin[i])) { + return NNACL_PARAM_INVALID; + } + + output->shape_[i] = size[i] < 0 ? 
input->shape_[i] - begin[i] : size[i]; + } + return NNACL_OK; +} + +REG_INFER(Slice, PrimType_SliceFusion, SliceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.h new file mode 100644 index 00000000..cdd0a09b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/slice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SLICE_INFER_H +#define MINDSPORE_NNACL_SLICE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SLICE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.c new file mode 100644 index 00000000..eeb4ce59 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
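Condensing the Slice bounds logic above: a size of -1 extends to the end of the axis, and begin/size must stay inside the input extent after the axis remap. Per-axis sketch with a hypothetical slice_out_dim helper.

#include <stdio.h>

// Returns the output extent for one axis, or -1 for out-of-range parameters.
static int slice_out_dim(int extent, int begin, int size) {
  if (begin < 0 || begin > extent) return -1;
  if (size < -1 || (size != -1 && size > extent - begin)) return -1;
  return size == -1 ? extent - begin : size;  // -1 takes the axis tail
}

int main(void) {
  printf("%d\n", slice_out_dim(10, 2, -1));  // 8: take the axis tail
  printf("%d\n", slice_out_dim(10, 2, 4));   // 4: explicit window
  printf("%d\n", slice_out_dim(10, 8, 4));   // -1: window overruns the axis
  return 0;
}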
+ */ + +#include "nnacl_c/infer/softmax_cross_entropy_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + out->shape_size_ = 2; + out->shape_[0] = in0->shape_[0]; + out->shape_[1] = 1; + SetDataTypeFormat(out, in0); + + if (1 < outputs_size) { + TensorC *grads = outputs[1]; + SetShapeTensor(grads, in0); + SetDataTypeFormat(grads, in0); + } + return NNACL_OK; +} + +REG_INFER(SoftmaxCrossEntropyWithLogits, PrimType_SoftmaxCrossEntropyWithLogits, SoftmaxCrossEntropyInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.h new file mode 100644 index 00000000..ac407fb5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_cross_entropy_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SOFTMAX_CROSS_ENTROPY_INFER_H +#define MINDSPORE_NNACL_SOFTMAX_CROSS_ENTROPY_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SOFTMAX_ENTROPY_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.c new file mode 100644 index 00000000..12dd4f68 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/softmax_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SoftMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + // there is a model with an 8-dim input, which runs on ascend910. + if (input->shape_size_ > DIMENSION_8D) { + return NNACL_ERR; + } + + SoftmaxParameter *param = (SoftmaxParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ < (-1 * (int)(input->shape_size_)) || param->axis_ > (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Softmax, PrimType_Softmax, SoftMaxInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.h new file mode 100644 index 00000000..556f46d4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/softmax_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SOFTMAX_INFER_H +#define MINDSPORE_NNACL_SOFTMAX_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SoftMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SOFTMAX_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.c new file mode 100644 index 00000000..9e5d6f83 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/space_to_batch_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SpaceToBatchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + SpaceToBatchParameter *param = (SpaceToBatchParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + + int *block_shape = param->block_sizes_; + int block_shape_size = param->m_; + int *paddings = param->paddings_; + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (block_shape_size == 2) { + padding_left = paddings[2]; + padding_right = paddings[3]; + block_w = block_shape[1]; + } + + NNACL_CHECK_ZERO_RETURN_ERR(block_shape[0]); + NNACL_CHECK_ZERO_RETURN_ERR(block_w); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_shape[0], block_w, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input->shape_[kNHWC_N], block_shape[0] * block_w, NNACL_ERR); + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N] * (block_shape[0] * block_w); + outputs[0]->shape_[kNHWC_H] = (input->shape_[kNHWC_H] + paddings[0] + paddings[1]) / block_shape[0]; + outputs[0]->shape_[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C]; + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(SpaceToBatch, PrimType_SpaceToBatch, SpaceToBatchInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.h new file mode 100644 index 00000000..d07d8ebf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToBatchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.c new file mode 100644 index 00000000..2415942a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.c @@ -0,0 +1,143 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/space_to_batch_nd_infer.h" +#include +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SpaceSetOutputShapeFromParam(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, const OpParameter *parameter) { + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + const SpaceToBatchParameter *param = (const SpaceToBatchParameter *)parameter; + const int *block_shape = param->block_sizes_; + int block_shape_size = param->m_; + const int *padding = param->paddings_; + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (block_shape_size == 2) { + padding_left = padding[2]; + padding_right = padding[3]; + block_w = block_shape[1]; + } + if (input->shape_[kNHWC_N] == 0 || block_shape[0] * block_w > INT_MAX / input->shape_[kNHWC_N]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N] * block_shape[0] * block_w; + if (padding[0] + padding[1] > INT_MAX - input->shape_[kNHWC_H]) { + return NNACL_ERR; + } + if (block_shape[0] == 0 || block_w == 0) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_H] = (input->shape_[kNHWC_H] + padding[0] + padding[1]) / block_shape[0]; + if (padding_left + padding_right > INT_MAX - input->shape_[kNHWC_W]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + if (input->shape_size_ > 3) { + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C]; + } + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +int SpaceSetOutputShapeFromInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + const TensorC *input = inputs[0]; + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + if (NNACLGetElementNum(inputs[2]) != 4) { + return NNACL_ERR; + } + int *block_shape = (int *)(inputs[1]->data_); + int *padding = (int *)(inputs[2]->data_); + int padding_left = 0; + int 
padding_right = 0; + int block_w = 1; + if (NNACLGetElementNum(inputs[1]) == 2) { + padding_left = padding[2]; + padding_right = padding[3]; + block_w = block_shape[1]; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input->shape_size_; + if (input->shape_[kNHWC_N] == 0 || block_shape[0] * block_w > INT_MAX / input->shape_[kNHWC_N]) { + return NNACL_ERR; + } + output_shape[kNHWC_N] = input->shape_[kNHWC_N] * block_shape[0] * block_w; + if (padding[0] + padding[1] > INT_MAX - input->shape_[kNHWC_H]) { + return NNACL_ERR; + } + if (block_shape[0] == 0 || block_w == 0) { + return NNACL_ERR; + } + output_shape[kNHWC_H] = (input->shape_[kNHWC_H] + padding[0] + padding[1]) / block_shape[0]; + if (padding_left + padding_right > INT_MAX - input->shape_[kNHWC_W]) { + return NNACL_ERR; + } + output_shape[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + if (input->shape_size_ > 3) { + output_shape[kNHWC_C] = input->shape_[kNHWC_C]; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_ERR; + } + outputs[0]->data_type_ = input->data_type_; + outputs[0]->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 1) { + int ret = SpaceSetOutputShapeFromParam(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + if (inputs_size == 3) { + if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int ret = SpaceSetOutputShapeFromInput(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +REG_INFER(SpaceToBatchND, PrimType_SpaceToBatchND, SpaceToBatchNdInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.h new file mode 100644 index 00000000..e1c07688 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_batch_nd_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.c new file mode 100644 index 00000000..09b5aecf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/space_to_depth_infer.h" +#include +#include "nnacl_c/infer/infer_register.h" + +int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + SpaceToDepthParameter *param = (SpaceToDepthParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + + int32_t block_size = param->block_size_; + if (block_size == 0) { + return NNACL_ERR; + } + if (input->shape_[kNHWC_H] % block_size != 0 || input->shape_[kNHWC_H] == 0 || + input->shape_[kNHWC_W] % block_size != 0 || input->shape_[kNHWC_W] == 0) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N]; + outputs[0]->shape_[kNHWC_H] = input->shape_[kNHWC_H] / block_size; + outputs[0]->shape_[kNHWC_W] = input->shape_[kNHWC_W] / block_size; + if (input->shape_[kNHWC_C] == 0 || block_size * block_size > INT_MAX / input->shape_[kNHWC_C]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C] * (block_size * block_size); + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(SpaceToDepth, PrimType_SpaceToDepth, SpaceToDepthInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.h new file mode 100644 index 00000000..88fe7f20 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/space_to_depth_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/space_to_depth_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.c new file mode 100644 index 00000000..c1369dd9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/sparse_fill_empty_rows_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseFillEmptyRowsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, C4NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input0); + + const TensorC *input1 = inputs[1]; + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output1, input1); + + TensorC *output2 = outputs[C2NUM]; + SetDataTypeFormat(output2, input0); + output2->data_type_ = kNumberTypeBool; + + if (outputs_size == C4NUM) { + TensorC *output3 = outputs[C3NUM]; + SetDataTypeFormat(output3, input0); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + return NNACL_INFER_INVALID; +} + +REG_INFER(SparseFillEmptyRows, PrimType_SparseFillEmptyRows, SparseFillEmptyRowsInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.h new file mode 100644 index 00000000..e6ce7882 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_fill_empty_rows_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H +#define MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseFillEmptyRowsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.c new file mode 100644 index 00000000..a00be2ba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.c @@ -0,0 +1,53 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/sparse_reshape_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C2NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_indices_tensor = inputs[0]; + TensorC *out_indices_tensor = outputs[0]; + SetDataTypeFormat(out_indices_tensor, in_indices_tensor); + + const TensorC *in_out_shape_tensor = inputs[C2NUM]; + TensorC *out_shape_tensor = outputs[C1NUM]; + SetDataTypeFormat(out_shape_tensor, in_out_shape_tensor); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeArray(out_shape_tensor, in_out_shape_tensor->shape_, in_out_shape_tensor->shape_size_); + + int out_indices_shape[MAX_SHAPE_SIZE] = {0}; + out_indices_shape[0] = in_indices_tensor->shape_[0]; + size_t out_indices_shape_size = 1; + + for (int i = 0; i < in_out_shape_tensor->shape_size_; ++i) { + out_indices_shape[i + 1] = in_out_shape_tensor->shape_[i]; + out_indices_shape_size++; + } + SetShapeArray(out_indices_tensor, out_indices_shape, out_indices_shape_size); + return NNACL_OK; +} + +REG_INFER(SparseReshape, PrimType_SparseReshape, SparseReshapeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.h new file mode 100644 index 00000000..e594ffff --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_reshape_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H +#define MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.c new file mode 100644 index 00000000..cc0263a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/sparse_segment_sum_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + return NNACL_OK; +} + +REG_INFER(SparseSegmentSum, PrimType_SparseSegmentSum, SparseSegmentSumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.h new file mode 100644 index 00000000..4589c724 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_segment_sum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H +#define MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.c new file mode 100644 index 00000000..84fb7d9e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h" +#include "nnacl_c/fp32_grad/softmax_grad.h" +#include "nnacl_c/infer/infer_register.h" + +int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + SoftmaxCrossEntropyParameter *param = (SoftmaxCrossEntropyParameter *)parameter; + if (param->is_grad_ != 0) { + SetShapeTensor(out, in0); + SetDataTypeFormat(out, in0); + } else { + out->shape_size_ = 1; + out->shape_[0] = 1; + SetDataTypeFormat(out, in0); + } + + return NNACL_OK; +} + +REG_INFER(SparseSoftmaxCrossEntropyWithLogits, PrimType_SparseSoftmaxCrossEntropyWithLogits, + SparseSoftmaxCrossEntropyWithLogitsInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h new file mode 100644 index 00000000..396b50e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_softmax_cross_entropy_with_logits_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ +#define MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.c new file mode 100644 index 00000000..1d017c73 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/infer/sparse_to_dense_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int SparseToDenseInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                            OpParameter *parameter) {
+  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  TensorC *output = outputs[0];
+  if (inputs_size < 3) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  const TensorC *input1 = inputs[1];
+  SetDataTypeFormat(output, input1);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  int *input1_data = (int *)(input1->data_);
+  int data_num = NNACLGetElementNum(input1);
+  if (input1_data == NULL || data_num > MAX_SHAPE_SIZE) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  int output_shape[MAX_SHAPE_SIZE] = {0};
+  size_t output_shape_size = 0;
+  for (int i = 0; i < data_num; i++) {
+    ShapePush(output_shape, &output_shape_size, input1_data[i]);
+  }
+  SetShapeArray(output, output_shape, output_shape_size);
+  return NNACL_OK;
+}
+
+REG_INFER(SparseToDense, PrimType_SparseToDense, SparseToDenseInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.h
new file mode 100644
index 00000000..9a521b55
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/sparse_to_dense_infer.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_SPARSE_TO_DENSE_INFER_H
+#define MINDSPORE_NNACL_SPARSE_TO_DENSE_INFER_H
+
+#include "nnacl_c/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int SparseToDenseInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                            OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_SPARSE_TO_DENSE_INFER_H
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.c
new file mode 100644
index 00000000..471af164
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.c
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/infer/splice_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != DIMENSION_3D) { + return NNACL_INPUT_TENSOR_ERROR; + } + SpliceParameter *param = (SpliceParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + int out_dim = param->output_dim_; + ShapeSet(output->shape_, &output->shape_size_, input->shape_, input->shape_size_); + + if (param->context_dim_ == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + if (param->forward_indexes_dim_ % param->context_dim_ != 0) { + return NNACL_PARAM_INVALID; + } + int out_size = param->forward_indexes_dim_ / param->context_dim_; + output->shape_[DIMENSION_1D] = out_size; + output->shape_[DIMENSION_2D] = out_dim; + return NNACL_OK; +} + +REG_INFER(Splice, PrimType_Splice, SpliceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.h new file mode 100644 index 00000000..312b1ee7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/splice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ +#define MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/splice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.c new file mode 100644 index 00000000..1c4b79b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.c @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/split_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UpdateSplitSize(const TensorC *const *inputs, size_t inputs_size, SplitParameter *param) { + // get split size from the second input. + if (inputs_size == DIMENSION_2D && inputs[SECOND_INPUT]->data_ != NULL) { + if (inputs[SECOND_INPUT]->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + int split_count = 1; + for (size_t i = 0; i < inputs[SECOND_INPUT]->shape_size_; i++) { + split_count *= inputs[SECOND_INPUT]->shape_[i]; + } + param->split_count_ = split_count; + for (int i = 0; i < split_count; i++) { + param->split_sizes_[i] = ((int *)(inputs[SECOND_INPUT]->data_))[i]; + } + } + if (param->split_count_ == 0) { + const TensorC *input = inputs[0]; + int32_t split_chunk_size = UP_DIV(input->shape_[param->split_dim_], param->num_split_); + for (int i = 0; i < param->num_split_; ++i) { + if (i != param->num_split_ - 1) { + param->split_sizes_[i] = split_chunk_size; + } else { + param->split_sizes_[i] = input->shape_[param->split_dim_] - split_chunk_size * i; + } + } + } + return NNACL_OK; +} + +int SetSplitOutputShape(const TensorC *input, TensorC **outputs, SplitParameter *param) { + for (int i = 0; i < param->num_split_; ++i) { + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int split_dim_i = input->shape_[param->split_dim_]; + if (i == param->num_split_ - 1 && param->split_sizes_[i] == -1) { + if (param->num_split_ - 1 < 0) { + return NNACL_ERR; + } + for (int j = 0; j < param->num_split_ - 1; ++j) { + split_dim_i -= param->split_sizes_[j]; + } + param->split_sizes_[i] = split_dim_i; + } else { + split_dim_i = param->split_sizes_[i]; + } + NNACL_CHECK_TRUE_RET(split_dim_i >= 0 && split_dim_i <= input->shape_[param->split_dim_], NNACL_ERR); + output_shape[param->split_dim_] = split_dim_i; + SetShapeArray(outputs[i], output_shape, output_shape_size); + SetDataTypeFormat(outputs[i], input); + } + return NNACL_OK; +} + +int SplitInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + SplitParameter *param = (SplitParameter *)parameter; + + int num_split = param->num_split_ == 0 ? (int)(outputs_size) : param->num_split_; + if (num_split == 0) { + return NNACL_ERR; + } + param->num_split_ = num_split; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int split_dim = param->split_dim_ < 0 ? 
((int)(input->shape_size_)) + param->split_dim_ : param->split_dim_; + if (split_dim >= (int)(input->shape_size_) || split_dim < 0) { + return NNACL_ERR; + } + param->split_dim_ = split_dim; + if ((int)(outputs_size) != num_split) { + return NNACL_ERR; + } + + int ret = UpdateSplitSize(inputs, inputs_size, param); + if (ret != NNACL_OK) { + return ret; + } + ret = SetSplitOutputShape(input, outputs, param); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +REG_INFER(Split, PrimType_Split, SplitInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.h new file mode 100644 index 00000000..acb8f5e1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_INFER_H +#define MINDSPORE_NNACL_SPLIT_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.c new file mode 100644 index 00000000..39c1463e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/split_reduce_concat_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/split_parameter.h" + +int SplitReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + NNACL_CHECK_TRUE_RET(inputs_size == outputs_size, NNACL_INPUT_TENSOR_ERROR); + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + out_tensor->format_ = in_tensor->format_; + for (size_t i = 0; i < in_tensor->shape_size_; i++) { + out_tensor->shape_[i] = in_tensor->shape_[i]; + } + SplitParameter *param = (SplitParameter *)parameter; + out_tensor->shape_[param->split_dim_] = param->num_split_; + out_tensor->shape_size_ = in_tensor->shape_size_; + return NNACL_OK; +} + +REG_INFER(SplitReduceConcatFusion, PrimType_Inner_SplitReduceConcatFusion, SplitReduceConcatFusionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.h new file mode 100644 index 00000000..80cfc65c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_reduce_concat_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H +#define MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.c new file mode 100644 index 00000000..0588fb19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.c @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/split_with_over_lap_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/op_base.h" + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input = inputs[0]; + SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter; + + int split_dim = param->split_dim_; + int number_split = param->num_split_; + if (outputs_size != (size_t)number_split) { + return NNACL_ERR; + } + + int ratio[SPLIT_MAX_SLICE_NUM]; + int extend_top[SPLIT_MAX_SLICE_NUM]; + int extend_bottom[SPLIT_MAX_SLICE_NUM]; + for (int i = 0; i < number_split; ++i) { + ratio[i] = param->ratio_[i]; + extend_top[i] = param->extend_top_[i]; + extend_bottom[i] = param->extend_bottom_[i]; + } + + const int *input_shape = input->shape_; + int split_dim_size = input_shape[split_dim]; + int total_block_count = 0; + for (int i = 0; i < number_split; i++) { + total_block_count += ratio[i]; + } + + int borders[MAX_SHAPE_SIZE]; + borders[0] = 0; + int visited_block = 0; + for (int i = 0; i < number_split - 1; i++) { + visited_block += ratio[i]; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(split_dim_size, visited_block) || total_block_count == 0, NNACL_ERR); + int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); + borders[i + 1] = cur_border; + } + borders[number_split] = split_dim_size; + + for (int i = 0; i < number_split; ++i) { + int output_shape[MAX_SHAPE_SIZE]; + for (int dim = 0; dim < input->shape_size_; dim++) { + if (dim == split_dim) { + int splited_size = borders[i + 1] - borders[i]; + splited_size += (extend_top[i] + extend_bottom[i]); + output_shape[dim] = splited_size; + } else { + output_shape[dim] = input_shape[dim]; + } + } + SetShapeArray(outputs[i], output_shape, input->shape_size_); + SetDataTypeFormat(outputs[i], input); + } + return NNACL_OK; +} + +REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.h new file mode 100644 index 00000000..96b9ef40 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/split_with_over_lap_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H +#define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.c new file mode 100644 index 00000000..b663436d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/squeeze_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int SqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = + CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, kInputSize1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SqueezeParameter *param = (SqueezeParameter *)parameter; + SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + + if (inputs_size == kInputSize1) { + NNACL_CHECK_TRUE_RET(inputs[1]->data_type_ == kNumberTypeInt32 || inputs[1]->data_type_ == kNumberTypeInt, + NNACL_PARAM_INVALID); + int *axis_data = (int *)(inputs[1]->data_); + NNACL_CHECK_TRUE_RET(axis_data != NULL, NNACL_PARAM_INVALID); + param->axis_size_ = NNACLGetElementNum(inputs[1]); + for (size_t i = 0; i < param->axis_size_; i++) { + param->axis_[i] = *(axis_data + i); + } + } + if (param->axis_size_ > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + + for (size_t i = 0; i < param->axis_size_; i++) { + param->axis_[i] = param->axis_[i] >= 0 ? 
param->axis_[i] : param->axis_[i] + (int)input->shape_size_; + } + + if (param->axis_size_ == 0) { + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != 1) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } else { + size_t axisIdx = 0; + for (size_t i = 0; i < input->shape_size_; i++) { + if (axisIdx < param->axis_size_ && param->axis_[axisIdx] == (int)(i)) { + if (input->shape_[i] != 1) return NNACL_PARAM_INVALID; + axisIdx++; + continue; + } else { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } + SetShapeArray(outputs[0], out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Squeeze, PrimType_Squeeze, SqueezeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.h new file mode 100644 index 00000000..ef2a773a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/squeeze_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SQUEEZE_INFER_H +#define MINDSPORE_NNACL_SQUEEZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/squeeze_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SQUEEZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.c new file mode 100644 index 00000000..d8772575 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.c @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/stack_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (outputs_size != 1) { + return NNACL_PARAM_INVALID; + } + if (inputs_size < 1) { + return NNACL_PARAM_INVALID; + } + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + StackParameter *param = (StackParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int axis = param->axis_ < 0 ? (int)(param->axis_) + (int)(input->shape_size_) + 1 : param->axis_; + if (axis < 0 || axis > (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->shape_size_ != input->shape_size_) { + return NNACL_PARAM_INVALID; + } + for (size_t j = 0; j < input->shape_size_; ++j) { + if (inputs[i]->shape_[j] != input->shape_[j]) { + return NNACL_PARAM_INVALID; + } + } + if (inputs[i]->data_type_ != input->data_type_) { + return NNACL_PARAM_INVALID; + } + } + int insert_ret = ShapeInsert(output_shape, &output_shape_size, axis, inputs_size); + if (insert_ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Stack, PrimType_Stack, StackInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.h new file mode 100644 index 00000000..aec77cce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/stack_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_STACK_INFER_H +#define MINDSPORE_NNACL_STACK_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/stack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STACK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.c new file mode 100644 index 00000000..17e28625 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.c @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/strided_slice_grad_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +bool StridedSliceCheckInputs(const TensorC *const *inputs, size_t inputs_size) { + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->data_ == NULL) { + return false; + } + } + if (NNACLGetElementNum(inputs[2]) > MAX_SHAPE_SIZE) { + return false; + } + if (NNACLGetElementNum(inputs[2]) != NNACLGetElementNum(inputs[3]) && + NNACLGetElementNum(inputs[2]) != NNACLGetElementNum(inputs[4])) { + return false; + } + return true; // note: the original code is ndim_ <= in_shape_size +} + +void ApplyBeginEndEllipsisMask(size_t ndim, int *begins, const uint32_t *const begins_mask, int *ends, + const uint32_t *const ends_mask, const uint32_t *const ellipsis_mask, + const int *const in_shape) { + for (size_t i = 0; i < ndim; i++) { + if (begins_mask[i]) { + begins[i] = 0; + } + if (ends_mask[i]) { + ends[i] = in_shape[i]; + } + } + for (size_t i = 0; i < ndim; i++) { + if (ellipsis_mask[i]) { + begins[i] = 0; + ends[i] = in_shape[i]; + break; + } + } +} + +int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + bool inferflag = InferFlag(inputs, inputs_size); + + int in_shape_[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + if (inferflag) { + ShapeSet(in_shape_, &in_shape_size, input->shape_, input->shape_size_); + } + int begins_[MAX_SHAPE_SIZE] = {0}; + size_t begins_size = 0; + int ends_[MAX_SHAPE_SIZE] = {0}; + size_t ends_size = 0; + int strides_[MAX_SHAPE_SIZE] = {0}; + size_t strides_size = 0; + + if (!StridedSliceCheckInputs(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // input order: dy, shapex, begins, ends, strides. 
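+  // Illustrative values (hypothetical, for orientation only): with shapex = {4, 6},
+  // begins = {0, 0}, ends = {4, 6} and strides = {1, 1}, the masks below leave the
+  // ranges untouched and the inferred output shape is {4, 6}, read back from inputs[1].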
+ const TensorC *begin_tensor = inputs[2]; + int *begin_data = (int *)(begin_tensor->data_); + int *end_data = (int *)(inputs[3]->data_); + int *stride_data = (int *)(inputs[4]->data_); + + size_t ndim_ = (size_t)NNACLGetElementNum(begin_tensor); + if (ndim_ > MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + for (size_t i = 0; i < ndim_; ++i) { + ShapePush(begins_, &begins_size, begin_data[i]); + ShapePush(ends_, &ends_size, end_data[i]); + ShapePush(strides_, &strides_size, stride_data[i]); + } + + // set all mask to original input shape + uint32_t begins_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ends_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ellipsis_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t new_axis_mask_[MAX_SHAPE_SIZE] = {0}; + + StridedSliceParameter *param = (StridedSliceParameter *)parameter; + for (size_t i = 0; i < ndim_; i++) { + begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); + } + param->num_axes_ = (int)(in_shape_size); + param->in_shape_length_ = (int)(in_shape_size); + for (size_t i = 0; i < ndim_; ++i) { + param->begins_[i] = begins_[i]; + param->ends_[i] = ends_[i]; + param->strides_[i] = strides_[i]; + } + ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_); + // ApplyNewAxisMask; + for (size_t i = 0; i < ndim_; i++) { + if (new_axis_mask_[i]) { + ndim_ += 1; + int ret = ShapeInsert(in_shape_, &in_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + begins_[i] = 0; + ends_[i] = 1; + strides_[i] = 1; + + ShapePush(begins_, &begins_size, 0); + ShapePush(ends_, &ends_size, in_shape_[ndim_ - 1]); + ShapePush(strides_, &strides_size, 1); + + begins_mask_[i] = false; + ends_mask_[i] = false; + ellipsis_mask_[i] = false; + } + } + ApplyBeginEndEllipsisMask(ndim_, begins_, begins_mask_, ends_, ends_mask_, ellipsis_mask_, in_shape_); + if (!inferflag) { + return NNACL_OK; + } + int output_size = inputs[1]->shape_[0]; + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (inputs[1]->data_ == NULL) { + return NNACL_ERR; + } + + if (output_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < output_size; i++) { + ShapePush(output_shape, &output_shape_size, ((int *)(inputs[1]->data_))[i]); + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(StridedSliceGrad, PrimType_StridedSliceGrad, StridedSliceGradInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.h new file mode 100644 index 00000000..caba55ba --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H +#define MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c new file mode 100644 index 00000000..89cbf0eb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.c @@ -0,0 +1,483 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/strided_slice_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c_utils.h" + +const size_t kStridedSliceOutputNum = 1; +const size_t kStridedSliceInputNum = 1; +const size_t kStridedSliceMultiInputNumMin = 3; +const size_t kStridedSliceMultiInputNumMax = 5; + +typedef struct StridedSliceTransferBuffer { + int ndim_; + + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int begins_mask_[MAX_SHAPE_SIZE]; + int ends_mask_[MAX_SHAPE_SIZE]; + int ellipsis_mask_[MAX_SHAPE_SIZE]; + int new_axis_mask_[MAX_SHAPE_SIZE]; + int shrink_axis_mask_[MAX_SHAPE_SIZE]; + + size_t begins_size_; + size_t ends_size_; + size_t strides_size_; + size_t ellipsis_mask_size_; + size_t new_axis_mask_size_; + size_t shrink_axis_mask_size_; +} StridedSliceTransferBuffer; + +bool CheckInputs(const TensorC *const *inputs, size_t inputs_size) { + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->data_ == NULL) { + return false; + } + } + return true; +} + +int HandleAxesCheckNull(const TensorC *input_tensor, const TensorC *begin_tensor, int *begin_data, + const TensorC *end_tensor, int *end_data) { + if (input_tensor == NULL || begin_tensor == NULL || end_tensor == NULL || begin_data == NULL || end_data == NULL) { + return NNACL_NULL_PTR; + } + return NNACL_OK; +} + +int HandleAxesInputNotExist(const TensorC *const *inputs, struct StridedSliceTransferBuffer *transfer_buffer) { + const TensorC *begin_tensor = inputs[1]; + const TensorC *end_tensor = inputs[2]; + const TensorC *stride_tensor = inputs[3]; + int ret = GetInt32DataFromTensor(begin_tensor, transfer_buffer->begins_, &transfer_buffer->begins_size_); + if (ret != NNACL_OK) { + return ret; + } + transfer_buffer->ndim_ = NNACLGetElementNum(begin_tensor); + ret = GetInt32DataFromTensor(end_tensor, transfer_buffer->ends_, &transfer_buffer->ends_size_); + if (ret != NNACL_OK) { + return ret; + } + ret = GetInt32DataFromTensor(stride_tensor, transfer_buffer->strides_, &transfer_buffer->strides_size_); + if (ret != NNACL_OK) { + return ret; + 
}
+  return NNACL_OK;
+}
+
+int GenerateAxes(const TensorC *axes_tensor, int *axes, int num, int ndim) {
+  int *axes_data = NULL;
+  if (NNACLGetElementNum(axes_tensor) != 0) {
+    if (NNACLGetElementNum(axes_tensor) != num) {
+      return NNACL_ERR;
+    }
+    axes_data = (int *)(axes_tensor->data_);
+    if (axes_data == NULL) {
+      return NNACL_NULL_PTR;
+    }
+  }
+  if (axes_data == NULL) {
+    for (int i = 0; i < num; ++i) {
+      axes[i] = i;
+    }
+  } else {
+    for (int i = 0; i < num; i++) {
+      axes[i] = axes_data[i];
+    }
+    for (int i = 0; i < num; ++i) {
+      if (axes[i] < 0) {
+        axes[i] += ndim;
+      }
+    }
+  }
+  return NNACL_OK;
+}
+
+int HandleAxesInputExist(const TensorC *const *inputs, int *ndim, int *in_shape, int *begins, int *strides, int *ends) {
+  const TensorC *input_tensor = inputs[0];
+  const TensorC *begin_tensor = inputs[1];
+  int begin_data[MAX_SHAPE_SIZE];
+  size_t begin_data_size;
+  int ret = GetInt32DataFromTensor(begin_tensor, begin_data, &begin_data_size);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  const TensorC *end_tensor = inputs[2];
+  int end_data[MAX_SHAPE_SIZE];
+  size_t end_data_size;
+  ret = GetInt32DataFromTensor(end_tensor, end_data, &end_data_size);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  int handle_check_ret = HandleAxesCheckNull(input_tensor, begin_tensor, begin_data, end_tensor, end_data);
+  if (handle_check_ret != NNACL_OK) {
+    return handle_check_ret;
+  }
+
+  // when the input contains axes, begins, ends and strides will be expanded to the same length as the input rank
+  *ndim = (int)(input_tensor->shape_size_);
+  int begin_ndim = NNACLGetElementNum(begin_tensor);
+
+  int *stride_data = NULL;
+  const TensorC *stride_tensor = inputs[4];
+  int stride_data_num = NNACLGetElementNum(stride_tensor);
+  if (stride_data_num != 0) {
+    NNACL_CHECK_TRUE_RET(stride_data_num == begin_ndim, NNACL_ERR);
+    stride_data = (int *)(stride_tensor->data_);
+  }
+
+  const TensorC *axes_tensor = inputs[3];
+  int axes[MAX_SHAPE_SIZE] = {0};
+  ret = GenerateAxes(axes_tensor, axes, begin_ndim, *ndim);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  if (*ndim > MAX_SHAPE_SIZE || *ndim < 0) {
+    return NNACL_ERR;
+  }
+  for (int i = 0; i < *ndim; i++) {
+    in_shape[i] = 0;
+    begins[i] = 0;
+    strides[i] = 0;
+  }
+  for (int i = 0; i < *ndim; ++i) {
+    in_shape[i] = input_tensor->shape_[i];
+  }
+  for (int i = 0; i < *ndim; ++i) {
+    int axes_it = 0;
+    if (begin_ndim > MAX_SHAPE_SIZE || begin_ndim < 0) {
+      return NNACL_ERR;
+    }
+    for (int j = 0; j < begin_ndim; j++) {
+      if (axes[j] == i) {
+        axes_it = j;
+        break;
+      } else {
+        axes_it++;
+      }
+    }
+    if (axes_it != begin_ndim) {
+      int axis = axes_it;
+      if (begin_data[axis] > input_tensor->shape_[i] - 1) {
+        begins[i] = begin_data[axis];
+      } else {
+        begins[i] = imax(imin(begin_data[axis], input_tensor->shape_[i] - 1), -input_tensor->shape_[i]);
+      }
+      // ends that exceed the limit are clamped to the limit
+      ends[i] = imax(imin(end_data[axis], input_tensor->shape_[i]), -input_tensor->shape_[i] - 1);
+      if (stride_data == NULL) {
+        return NNACL_ERR;
+      }
+      strides[i] = stride_data[axis];
+    } else {
+      begins[i] = 0;
+      ends[i] = input_tensor->shape_[i];
+      strides[i] = 1;
+    }
+  }
+  return NNACL_OK;
+}
+
+int StrideSlicePreCheck(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter) {
+  if (outputs_size != kStridedSliceOutputNum) {
+    return NNACL_PARAM_INVALID;
+  }
+  if (inputs_size != kStridedSliceInputNum &&
+      !(inputs_size <= kStridedSliceMultiInputNumMax && inputs_size >= kStridedSliceMultiInputNumMin)) {
+    return NNACL_PARAM_INVALID;
+ } + if (parameter == NULL || outputs[0] == NULL || inputs[0] == NULL) { + return NNACL_NULL_PTR; + } + if (inputs_size >= kStridedSliceMultiInputNumMin) { + bool begins_type_ok = + (inputs[C1NUM]->data_type_ == kNumberTypeInt32) || (inputs[C1NUM]->data_type_ == kNumberTypeInt64); + bool ends_type_ok = + (inputs[C2NUM]->data_type_ == kNumberTypeInt32) || (inputs[C2NUM]->data_type_ == kNumberTypeInt64); + if (!(begins_type_ok && ends_type_ok)) { + return NNACL_PARAM_INVALID; + } + } + return NNACL_OK; +} + +void Bit2Vector(StridedSliceTransferBuffer *transfer_buffer, const StridedSliceParameter *param) { + for (unsigned i = 0; i < (unsigned)(size_t)(transfer_buffer->ndim_); i++) { + transfer_buffer->begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + transfer_buffer->ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + transfer_buffer->ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + transfer_buffer->new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); + transfer_buffer->shrink_axis_mask_[i] = (unsigned)(param->shrinkAxisMask_) & (1 << i); + } +} + +int ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape, + size_t *out_shape_size) { + for (size_t i = 0; i < transfer_buffer->new_axis_mask_size_; i++) { + if (transfer_buffer->new_axis_mask_[i]) { + if (*out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ret = ShapeInsert(in_shape, out_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = 1; + transfer_buffer->strides_[i] = 1; + + ShapePush(transfer_buffer->begins_, &transfer_buffer->begins_size_, 0); + ShapePush(transfer_buffer->ends_, &transfer_buffer->ends_size_, in_shape[(size_t)(transfer_buffer->ndim_) - 1]); + ShapePush(transfer_buffer->strides_, &transfer_buffer->strides_size_, 1); + + transfer_buffer->begins_mask_[i] = false; + transfer_buffer->ends_mask_[i] = false; + transfer_buffer->ellipsis_mask_[i] = false; + transfer_buffer->shrink_axis_mask_[i] = false; + } + } + return NNACL_OK; +} + +void ApplyBeginMask(StridedSliceTransferBuffer *transfer_buffer) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->begins_mask_[i]) { + transfer_buffer->begins_[i] = transfer_buffer->strides_[i] > 0 ? 0 : -1; + } + } +} + +int ApplyEndMask(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t in_shape_size) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->ends_mask_[i]) { + if ((size_t)i >= in_shape_size) { + return NNACL_ERR; + } + transfer_buffer->ends_[i] = transfer_buffer->strides_[i] > 0 ? 
in_shape[i] : -1 - in_shape[i]; + } + } + return NNACL_OK; +} + +int ApplyEllipsisMask(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t in_shape_size) { + for (size_t i = 0; i < transfer_buffer->ellipsis_mask_size_; i++) { + if (transfer_buffer->ellipsis_mask_[i]) { + if (i >= in_shape_size) { + return NNACL_ERR; + } + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = in_shape[i]; + break; + } + } + return NNACL_OK; +} + +int TransIndexToPositive(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t max_shape_size, + size_t in_shape_size) { + for (size_t i = 0; i < transfer_buffer->begins_size_; i++) { + if (i >= max_shape_size) { + return NNACL_ERR; + } + if (transfer_buffer->begins_[i] < 0) { + transfer_buffer->begins_[i] += in_shape[i]; + } + if (transfer_buffer->ends_[i] < 0) { + transfer_buffer->ends_[i] += in_shape[i]; + } + if (i < in_shape_size) { + if (transfer_buffer->begins_[i] < 0 || transfer_buffer->begins_[i] > in_shape[i]) { + return NNACL_ERR; + } + if ((transfer_buffer->ends_[i] < 0 && transfer_buffer->ends_[i] != -1) || + transfer_buffer->ends_[i] > in_shape[i]) { + return NNACL_ERR; + } + } + } + return NNACL_OK; +} + +void ApplyShrinkMask(StridedSliceTransferBuffer *transfer_buffer, int *output_shape, size_t *output_shape_size) { + int old_out_shape[MAX_SHAPE_SIZE] = {0}; + size_t old_out_shape_size = 0; + ShapeSet(old_out_shape, &old_out_shape_size, output_shape, *output_shape_size); + *output_shape_size = 0; + for (size_t i = 0; i < transfer_buffer->shrink_axis_mask_size_; i++) { + if (transfer_buffer->shrink_axis_mask_[i]) { + transfer_buffer->ends_[i] = transfer_buffer->begins_[i] + 1; + transfer_buffer->strides_[i] = 1; + } else { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } + } + for (size_t i = transfer_buffer->shrink_axis_mask_size_; i < old_out_shape_size; i++) { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } +} + +int TransferBuffer2Param(const StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, + const int *in_shape, size_t in_shape_size) { + if (transfer_buffer->ndim_ >= (int)(in_shape_size) || param->in_shape_length_ >= (int)(in_shape_size)) { + return NNACL_ERR; + } + for (int i = 0; i < transfer_buffer->ndim_; i++) { + param->begins_[i] = transfer_buffer->begins_[i]; + param->ends_[i] = transfer_buffer->ends_[i]; + param->in_shape_[i] = in_shape[i]; + param->strides_[i] = transfer_buffer->strides_[i]; + } + + for (int i = transfer_buffer->ndim_; i < param->in_shape_length_; i++) { + param->begins_[i] = 0; + param->ends_[i] = in_shape[i]; + param->in_shape_[i] = in_shape[i]; + param->strides_[i] = 1; + } + return NNACL_OK; +} + +void InitStridedSliceTransferBuffer(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->begins_size_ = 0; + transfer_buffer->ends_size_ = 0; + transfer_buffer->strides_size_ = 0; + transfer_buffer->ellipsis_mask_size_ = 0; + transfer_buffer->new_axis_mask_size_ = 0; + transfer_buffer->shrink_axis_mask_size_ = 0; +} + +void SetMaskSize(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->ellipsis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->new_axis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->shrink_axis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->begins_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->ends_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->strides_size_ = 
(size_t)(transfer_buffer->ndim_);
+}
+
+// note: begin, end and stride have equal lengths, which may be less than the rank of the input
+int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                           OpParameter *parameter) {
+  int check_ret = StrideSlicePreCheck(inputs, inputs_size, outputs, outputs_size, parameter);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  SetDataTypeFormat(outputs[0], inputs[0]);
+
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  int in_shape[MAX_SHAPE_SIZE] = {0};
+  size_t in_shape_size = 0;
+  if (input->shape_size_ > MAX_SHAPE_SIZE) {
+    return NNACL_ERR;
+  }
+  ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_);
+
+  StridedSliceTransferBuffer transfer_buffer;
+  InitStridedSliceTransferBuffer(&transfer_buffer);
+
+  StridedSliceParameter *param = (StridedSliceParameter *)parameter;
+
+  transfer_buffer.ndim_ = 0;
+  if (inputs_size == kStridedSliceInputNum) {
+    transfer_buffer.ndim_ = (int)(in_shape_size);
+    if (transfer_buffer.ndim_ > MAX_SHAPE_SIZE) {
+      return NNACL_ERR;
+    }
+    for (int i = 0; i < transfer_buffer.ndim_; i++) {
+      ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, param->begins_[i]);
+      ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, param->ends_[i]);
+      ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, param->strides_[i]);
+    }
+  }
+  if (!CheckInputs(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  if (inputs_size == 4) {
+    int ret = HandleAxesInputNotExist(inputs, &transfer_buffer);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+  }
+
+  if (inputs_size == 5) {
+    int ret = HandleAxesInputExist(inputs, &transfer_buffer.ndim_, in_shape, transfer_buffer.begins_,
+                                   transfer_buffer.strides_, transfer_buffer.ends_);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+  }
+
+  // set all masks to the original input shape
+  SetMaskSize(&transfer_buffer);
+  Bit2Vector(&transfer_buffer, param);
+  int ret = ApplyNewAxisMask(&transfer_buffer, param, in_shape, &in_shape_size);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  // update parameter with new input shape
+  param->num_axes_ = (int)(in_shape_size);
+  param->in_shape_length_ = (int)(in_shape_size);
+
+  ApplyBeginMask(&transfer_buffer);
+  ret = ApplyEndMask(&transfer_buffer, in_shape, MAX_SHAPE_SIZE);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  ret = ApplyEllipsisMask(&transfer_buffer, in_shape, MAX_SHAPE_SIZE);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  int output_shape[MAX_SHAPE_SIZE];
+  size_t output_shape_size = 0;
+  ShapeSet(output_shape, &output_shape_size, in_shape, in_shape_size);
+  ret = TransIndexToPositive(&transfer_buffer, in_shape, MAX_SHAPE_SIZE, input->shape_size_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  for (int i = 0; i < transfer_buffer.ndim_; i++) {
+    if (transfer_buffer.strides_[i] == 0 || in_shape[i] < transfer_buffer.ends_[i]) {
+      return NNACL_ERR;
+    }
+    output_shape[i] = (transfer_buffer.ends_[i] - transfer_buffer.begins_[i] + transfer_buffer.strides_[i] +
+                       (transfer_buffer.strides_[i] < 0 ? 1 : -1)) /
+                      transfer_buffer.strides_[i];
+    output_shape[i] = output_shape[i] > 0 ?
output_shape[i] : 0; + } + ApplyShrinkMask(&transfer_buffer, output_shape, &output_shape_size); + SetShapeArray(outputs[0], output_shape, output_shape_size); + ret = TransferBuffer2Param(&transfer_buffer, param, in_shape, MAX_SHAPE_SIZE); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +REG_INFER(StridedSlice, PrimType_StridedSlice, StridedSliceInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.h new file mode 100644 index 00000000..492f5b2a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/strided_slice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_STRIDED_SLICE_INFER_H +#define MINDSPORE_NNACL_STRIDED_SLICE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STRIDED_SLICE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.c new file mode 100644 index 00000000..889d2daf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/string/custom_extract_features_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomExtractFeaturesInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + output0->data_type_ = kNumberTypeInt32; + output0->format_ = input->format_; + output1->data_type_ = kNumberTypeFloat32; + output1->format_ = input->format_; + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int string_num = *((const int32_t *)(input->data_)); + + int res = (string_num == 0 ? 1 : string_num); + output0->shape_size_ = 1; + output0->shape_[0] = res; + output1->shape_size_ = 1; + output1->shape_[0] = res; + return NNACL_OK; +} + +REG_INFER(CustomExtractFeatures, PrimType_CustomExtractFeatures, CustomExtractFeaturesInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.h new file mode 100644 index 00000000..86149229 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_extract_features_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomExtractFeaturesInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.c new file mode 100644 index 00000000..4cbb57ee --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/string/custom_normalize_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int CustomNormalizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(input) < 1) { + return NNACL_ERR; + } + if (input->data_type_ != kNumberTypeUInt32 && input->data_type_ != kObjectTypeString) { + return NNACL_ERR; + } + int string_num = *((const int32_t *)(input->data_)); // also look custom_extract_features + output->shape_size_ = 1; + output->shape_[0] = (string_num == 0 ? 1 : string_num); + return NNACL_OK; +} + +REG_INFER(CustomNormalize, PrimType_CustomNormalize, CustomNormalizeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.h new file mode 100644 index 00000000..cd45bd5b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_normalize_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomNormalizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.c new file mode 100644 index 00000000..f0524c96 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/string/custom_predict_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int CustomPredictInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + CustomPredictParameter *param = (CustomPredictParameter *)parameter; + output0->shape_size_ = 1; + output0->shape_[0] = param->output_num; + output0->data_type_ = kNumberTypeInt32; + output0->format_ = input->format_; + output1->shape_size_ = 1; + output1->shape_[0] = param->output_num; + output1->data_type_ = kNumberTypeFloat32; + output1->format_ = input->format_; + return NNACL_OK; +} + +REG_INFER(CustomPredict, PrimType_CustomPredict, CustomPredictInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.h new file mode 100644 index 00000000..1d6410f6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/custom_predict_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct CustomPredictParameter { + OpParameter op_parameter_; + int output_num; +} CustomPredictParameter; + +int CustomPredictInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.c new file mode 100644 index 00000000..c26d5c21 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.c @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/string/hashtable_lookup_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int HashtableLoopupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *values = inputs[2]; + if (input == NULL || values == NULL) { + return NNACL_NULL_PTR; + } + + TensorC *output = outputs[0]; + TensorC *hits = outputs[1]; + + output->data_type_ = values->data_type_; + output->format_ = input->format_; + hits->shape_size_ = 1; + hits->shape_[0] = NNACLGetDimensionSize(input, 0); + hits->data_type_ = kNumberTypeUInt8; + hits->format_ = input->format_; + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(HashtableLookup, PrimType_HashtableLookup, HashtableLoopupInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.h new file mode 100644 index 00000000..9879e0e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/hashtable_lookup_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int HashtableLoopupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.c new file mode 100644 index 00000000..0422e800 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.c @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/string/lsh_projection_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int LshProjectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_hash = inputs[0]; + if (in_hash->shape_size_ != 2 || NNACLGetDimensionSize(in_hash, 1) > 32) { + return NNACL_ERR; + } + TensorC *out_tensor = outputs[0]; + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = Format_NHWC; + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + LshProjectionParameter *param = (LshProjectionParameter *)parameter; + switch (param->lsh_type_) { + case LshProjectionType_SPARSE: + ShapePush(out_shape, &out_shape_size, NNACLGetDimensionSize(in_hash, 0)); + break; + case LshProjectionType_DENSE: + ShapePush(out_shape, &out_shape_size, NNACLGetDimensionSize(in_hash, 0) * NNACLGetDimensionSize(in_hash, 1)); + break; + default: + return NNACL_ERR; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(LshProjection, PrimType_LshProjection, LshProjectionInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.h new file mode 100644 index 00000000..eb98f7e8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/lsh_projection_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/lsh_projection_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LshProjectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.c new file mode 100644 index 00000000..ebc45475 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/string/skip_gram_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int SkipGramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(SkipGram, PrimType_SkipGram, SkipGramInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.h new file mode 100644 index 00000000..5fd2f176 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/string/skip_gram_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SkipGramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.c new file mode 100644 index 00000000..e191b2b4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.c @@ -0,0 +1,111 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/tile_infer.h" +#include +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tile_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +void TileParamCaffe2Tflite(TileParameter *param, size_t out_shape_size) { + if (param->dims_size_ != 0) { + int multiples_size_tmp[5] = {0}; + NNACL_CHECK_TRUE_RET_VOID(out_shape_size <= 5); + for (size_t i = 0; i < out_shape_size; i++) { + multiples_size_tmp[i] = 1; + } + for (size_t i = 0; i < param->dims_size_; i++) { + if (i >= MAX_SHAPE_SIZE) { + return; + } + multiples_size_tmp[param->dims_[i]] = param->multiples_[i]; + } + for (size_t i = 0; i < 5; i++) { + param->multiples_[i] = multiples_size_tmp[i]; + } + } +} + +int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + TileParameter *param = (TileParameter *)parameter; + + size_t multiples_size = 0; + int input1_shape_size = inputs[1]->shape_size_; + if (input1_shape_size > (int)(input->shape_size_) || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + NNACL_CHECK_TRUE_RET(input1_shape_size <= MAX_SHAPE_SIZE, NNACL_ERR); + int data_num = NNACLGetElementNum(inputs[1]); + multiples_size = (size_t)(data_num); + if (inputs[1]->data_type_ != kNumberTypeInt && inputs[1]->data_type_ != kNumberTypeInt32) { + return NNACL_INPUT_TENSOR_ERROR; + } + int *input1_data = inputs[1]->data_; + if (input1_data == NULL) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(data_num <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int i = 0; i < data_num; i++) { + param->multiples_[i] = input1_data[i]; + } + + int *dims = param->dims_; + size_t dims_size = param->dims_size_; + if (dims_size == 0) { + int dim_num = NNACLGetElementNum(inputs[1]); + NNACL_CHECK_TRUE_RET(dim_num <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int dim = 0; dim < dim_num; ++dim) { + ShapePush(dims, &dims_size, dim); + } + param->dims_size_ = dims_size; + } + NNACL_CHECK_TRUE_RET(multiples_size == dims_size, NNACL_ERR); + for (size_t i = 0; i < input->shape_size_; ++i) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + for (size_t i = 0; i < dims_size; ++i) { + if (dims[i] >= MAX_SHAPE_SIZE || input->shape_[dims[i]] == 0) { + return NNACL_ERR; + } + if (input->shape_[dims[i]] != 0 && param->multiples_[i] > INT_MAX / input->shape_[dims[i]]) { + return NNACL_ERR; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input->shape_[dims[i]], (param->multiples_[i])), NNACL_ERR); + out_shape[dims[i]] = input->shape_[dims[i]] * (param->multiples_[i]); + } + // change caffe param format to tflite + TileParamCaffe2Tflite(param, out_shape_size); + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Tile, PrimType_TileFusion, TileInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.h new file mode 100644 index 00000000..ae1aedbc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/tile_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * 
+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TILE_INFER_H +#define MINDSPORE_NNACL_TILE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/base/tile_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TILE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.c new file mode 100644 index 00000000..fac77961 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.c @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/topk_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output0, input); + output1->data_type_ = kNumberTypeInt32; + output1->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input_k_tensor = inputs[1]; + if (input_k_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + + TopkParameter *param = (TopkParameter *)parameter; + param->k_ = ((int32_t *)input_k_tensor->data_)[0]; + + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); + if (out_shape_size < 1) { + return NNACL_ERR; + } + if (param->axis_ < 0) { + param->axis_ += (int)out_shape_size; + } + if (param->axis_ < 0 || (size_t)param->axis_ >= out_shape_size) { + return NNACL_ERR; + } + out_shape[(size_t)param->axis_] = param->k_; + + SetShapeArray(output0, out_shape, out_shape_size); + SetShapeArray(output1, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(TopK, PrimType_TopKFusion, TopKInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.h new file mode 100644 index 00000000..a08f06b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/topk_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TOPK_INFER_H +#define MINDSPORE_NNACL_TOPK_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/topk_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TOPK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.c new file mode 100644 index 00000000..7bfa6077 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.c @@ -0,0 +1,137 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/infer/transpose_infer.h"
+#include "nnacl_c/infer/infer_register.h"
+
+bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const int size) {
+  for (int i = 0; i < size; ++i) {
+    if (perm[i] != perm_transformat[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, const int *perm, size_t perm_size,
+                   int *out_shape) {
+  // set output shape
+  size_t in_shape_size = input->shape_size_;
+  output->shape_size_ = in_shape_size;
+  if (perm_size == 0) {
+    for (size_t i = 0; i < in_shape_size; ++i) {
+      out_shape[in_shape_size - i - 1] = input->shape_[i];
+    }
+  } else if (perm_size != in_shape_size) {
+    for (size_t i = 0; i < in_shape_size; ++i) {
+      out_shape[i] = input->shape_[i];
+    }
+  } else {
+    output->shape_size_ = perm_size;
+    for (size_t i = 0; i < perm_size; ++i) {
+      if (perm[i] >= input->shape_size_) {
+        return NNACL_ERR;
+      } else {
+        out_shape[i] = input->shape_[perm[i]];
+      }
+    }
+  }
+  return NNACL_OK;
+}
+
+int GetAndCheckPerm(const TensorC *perm_tensor, const int perms_num, int *perm, size_t *perm_size) {
+  if (perms_num >= MAX_TRANSPOSE_DIM_SIZE) {
+    return NNACL_TRANSPOSE_PERM_DIMS_INVALID;
+  }
+
+  int ret = GetInt32DataFromTensor(perm_tensor, perm, perm_size);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  for (size_t i = 0; i < *perm_size; i++) {
+    NNACL_CHECK_TRUE_RET(perm[i] < perms_num, NNACL_ERR);
+  }
+  return NNACL_OK;
+}
+
+void Handle4DPerm(const TensorC *input, TensorC *output, int *perm, size_t *perm_size) {
+  const int nchw2nhwc[4] = {Index0, Index2, Index3, Index1};
+  const int nhwc2nchw[4] = {Index0, Index3, Index1, Index2};
+  const int trans3d[3] = {Index0, Index2, Index1};
+  if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, PERM_NUM_FOUR)) {
+    output->format_ = Format_NHWC;
+  } else if ((input->format_ == Format_NHWC || input->format_ == Format_KHWC) &&
+             CheckPermTransFormat(perm, nhwc2nchw, PERM_NUM_FOUR)) {
+    output->format_ = Format_NCHW;
+  }
+  // though the perm is 4-D by default, the input can be a 3-D tensor; the op implementation must handle this case.
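+  // Illustrative example (editorial addition, not part of the original kernel):
+  // perm = {0, 2, 3, 1} on an NCHW input matches nchw2nhwc above, so the output
+  // format becomes NHWC; if that same 4-element perm arrives with a 3-D input,
+  // the branch below rewrites it to the 3-D transpose {0, 2, 1}.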
+  if (input->shape_size_ == DIMENSION_3D) {
+    ShapeSet(perm, perm_size, trans3d, DIMENSION_3D);
+  }
+}
+
+int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+
+  SetDataTypeFormat(output, input);
+  const TensorC *perm_tensor = inputs[1];
+  if (perm_tensor == NULL) {
+    return NNACL_INFER_INVALID;
+  }
+  NNACL_CHECK_TRUE_RET(perm_tensor->shape_size_ == 1, NNACL_INFER_INVALID);
+  const int perms_num = perm_tensor->shape_[0];
+  if (perms_num != 0 && perm_tensor->data_ == NULL) {
+    return NNACL_INFER_INVALID;
+  }
+  TransposeParameter *transpose_param = (TransposeParameter *)parameter;
+  transpose_param->perm_size_ = perms_num;
+  int perm[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  size_t perm_size = 0;
+  int ret = GetAndCheckPerm(perm_tensor, perms_num, perm, &perm_size);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  if (perms_num == PERM_NUM_FOUR) {
+    Handle4DPerm(input, output, perm, &perm_size);
+  }
+  int kPermIndex0 = 0;
+  int kPermIndex2 = 2;
+  if (perms_num == PERM_NUM_THREE && perm[0] == kPermIndex0 && perm[1] == kPermIndex2) {
+    output->format_ = input->format_ == Format_NCHW ? Format_NHWC : Format_NCHW;
+  }
+  if (parameter->quant_type_ == Quant_QuantWeight) {
+    output->data_type_ = kNumberTypeFloat32;
+  }
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+
+  // set output shape; propagate a failure from SetOutputShape instead of ignoring it
+  int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  ret = SetOutputShape(perms_num, input, output, perm, perm_size, out_shape);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+  SetShapeArray(output, out_shape, output->shape_size_);
+  return NNACL_OK;
+}
+
+REG_INFER(Transpose, PrimType_Transpose, TransposeInferShape)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.h
new file mode 100644
index 00000000..2557fcbd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/transpose_infer.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_NNACL_TRANSPOSE_INFER_H +#define MINDSPORE_NNACL_TRANSPOSE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TRANSPOSE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.c new file mode 100644 index 00000000..faa45c8c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/triu_tril_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int TriuTrilInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + const size_t triul_input_min_size = 1; + const size_t triul_output_size = 1; + if (inputs_size < triul_input_min_size || outputs_size != triul_output_size) { + return NNACL_ERR; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Triu, PrimType_Triu, TriuTrilInferShape) +REG_INFER(Tril, PrimType_Tril, TriuTrilInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.h new file mode 100644 index 00000000..5dd85d45 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/triu_tril_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_TRIU_TRIL_INFER_H +#define MINDSPORE_NNACL_TRIU_TRIL_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/arg_min_max_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TriuTrilInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TRIU_TRIL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.c new file mode 100644 index 00000000..4614c7ac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/uniform_real_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" + +int UniformRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + outputs[0]->data_type_ = kNumberTypeFloat32; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int32_t *input_data = (int32_t *)(inputs[0]->data_); + if (input_data == NULL) { + return NNACL_INFER_INVALID; + } + int input_num = NNACLGetElementNum(inputs[0]); + if (input_num > MAX_SHAPE_SIZE || input_num < 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = (size_t)(input_num); + for (int i = 0; i < input_num; i++) { + output_shape[i] = input_data[i]; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(UniformReal, PrimType_UniformReal, UniformRealInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.h new file mode 100644 index 00000000..d3aad6a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/uniform_real_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_UNIFORM_REAL_INFER_H +#define MINDSPORE_NNACL_UNIFORM_REAL_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UniformRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNIFORM_REAL_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.c new file mode 100644 index 00000000..ab24aebe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/unique_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UniqueInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + SetDataTypeFormat(output0, input0); + output1->data_type_ = kNumberTypeInt32; + output1->format_ = input0->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input0); + SetShapeTensor(output1, input0); + return NNACL_OK; +} + +REG_INFER(Unique, PrimType_Unique, UniqueInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.h new file mode 100644 index 00000000..b97e37de --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unique_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_NNACL_UNIQUE_INFER_H +#define MINDSPORE_NNACL_UNIQUE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UniqueInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNIQUE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.c new file mode 100644 index 00000000..4fdfe03f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/unsorted_segment_sum_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UnsortedSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *out = outputs[0]; + const TensorC *x = inputs[0]; + const TensorC *segment_id = inputs[1]; + if (inputs[2]->data_ == NULL || + (inputs[2]->data_type_ != kNumberTypeInt && inputs[2]->data_type_ != kNumberTypeInt32)) { + return NNACL_INPUT_TENSOR_ERROR; + } + int num_segments = *(int *)(inputs[2]->data_); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapePush(output_shape, &output_shape_size, num_segments); + for (int index = (int)(segment_id->shape_size_); index < (int)(x->shape_size_); index++) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, x->shape_[index]); + } + SetShapeArray(out, output_shape, output_shape_size); + SetDataTypeFormat(out, x); + return NNACL_OK; +} + +REG_INFER(UnsortedSegmentSum, PrimType_UnsortedSegmentSum, UnsortedSegmentSumInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.h new file mode 100644 index 00000000..f6332f5a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsorted_segment_sum_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H +#define MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct UnsortedSegmentSumParameter { + OpParameter op_parameter_; + int segments_num_; +} UnsortedSegmentSumParameter; + +int UnsortedSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.c new file mode 100644 index 00000000..3bbf1d28 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.c @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/unsqueeze_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UnsqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + UnSqueezeParameter *param = (UnSqueezeParameter *)parameter; + int in_rank = (int)(input->shape_size_); + int dim_rank = param->num_dim_; + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + if (dim_rank == 0) { + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != 1) { + if (out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } else { + int sz = in_rank + dim_rank; + size_t in_itr = 0; + size_t ax_itr = 0; + if (sz < 0) { + return NNACL_ERR; + } + for (int i = 0; i < sz; i++) { + if (out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + if (ax_itr < (size_t)(dim_rank) && param->dims_[ax_itr] == (int)(i)) { + ShapePush(out_shape, &out_shape_size, 1); + ax_itr++; + } else if (ax_itr < (size_t)(dim_rank) && param->dims_[ax_itr] + sz == i) { + ShapePush(out_shape, &out_shape_size, 1); + ax_itr++; + } else { + if (in_itr >= input->shape_size_) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, input->shape_[in_itr]); + in_itr++; + } + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Unsqueeze, PrimType_Unsqueeze, UnsqueezeInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.h 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.h new file mode 100644 index 00000000..fa7b96e4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unsqueeze_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSQUEEZE_INFER_H +#define MINDSPORE_NNACL_UNSQUEEZE_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/unsqueeze_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnsqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSQUEEZE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.c new file mode 100644 index 00000000..b604c1b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/infer/unstack_infer.h" +#include "nnacl_c/infer/infer_register.h" + +int UnstackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + UnstackParameter *param = (UnstackParameter *)parameter; + int axis = param->axis_ < 0 ? 
param->axis_ + (int)(input->shape_size_) : param->axis_; + if (axis < 0 || axis >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (size_t i = 0; i < input->shape_size_; ++i) { + if (i != (size_t)(axis)) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, input->shape_[i]); + } + } + for (size_t i = 0; i < outputs_size; i++) { + if (outputs[i] == NULL) { + return NNACL_NULL_PTR; + } + SetShapeArray(outputs[i], output_shape, output_shape_size); + } + return NNACL_OK; +} + +REG_INFER(Unstack, PrimType_Unstack, UnstackInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.h new file mode 100644 index 00000000..386447e8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/unstack_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSTACK_INFER_H +#define MINDSPORE_NNACL_UNSTACK_INFER_H + +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/unstack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnstackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSTACK_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.c new file mode 100644 index 00000000..06a10cbe --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.c @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/infer/where_infer.h" +#include "nnacl_c/infer/infer_register.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/infer/broadcast_to_infer.h" + +int WhereBroadCastInferShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1, int *out_shape, + bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(output); + + // Need to dynamically allocate at runtime. + if (inputs_size == 1) { + output->data_type_ = kNumberTypeInt32; + output->format_ = input0->format_; + return NNACL_INFER_INVALID; + } + + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input1 = inputs[1]; + const TensorC *input2 = inputs[2]; + NNACL_CHECK_NULL_RETURN_ERR(input1); + NNACL_CHECK_NULL_RETURN_ERR(input2); + SetDataTypeFormat(output, input1); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + int in_shape2[MAX_SHAPE_SIZE] = {0}; + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape0_size = input0->shape_size_; + size_t input_shape1_size = input1->shape_size_; + size_t input_shape2_size = input2->shape_size_; + const int *input_shape0 = input0->shape_; + const int *input_shape1 = input1->shape_; + const int *input_shape2 = input2->shape_; + int ndim = (int)input_shape0_size; + bool has_broad_cast_1 = false; + bool has_broad_cast_2 = false; + if (WhereBroadCastInferShape(input_shape0_size, input_shape1_size, input_shape0, input_shape1, &ndim, in_shape0, + in_shape1, output_shape, &has_broad_cast_1) != NNACL_OK) { + return NNACL_ERR; + } + if (WhereBroadCastInferShape(ndim, input_shape2_size, output_shape, input_shape2, &ndim, in_shape0, in_shape2, + output_shape, &has_broad_cast_2) != NNACL_OK) { + return NNACL_ERR; + } + ShapeSet(output->shape_, &output->shape_size_, output_shape, ndim); + return NNACL_OK; +} + +REG_INFER(Where, PrimType_Where, WhereInferShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.h new file mode 100644 index 00000000..6dadfc79 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/infer/where_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_WHERE_INFER_H +#define MINDSPORE_NNACL_WHERE_INFER_H + +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_WHERE_INFER_H diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/instance_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/instance_norm_parameter.h new file mode 100644 index 00000000..daa82e28 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/instance_norm_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INSTANCE_NORM_PARAMETER_H_ +#define NNACL_INSTANCE_NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct InstanceNormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + // shape correlative + int batch_; + int channel_; + int inner_size_; +} InstanceNormParameter; + +#endif // NNACL_INSTANCE_NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.c new file mode 100644 index 00000000..a1f54b43 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.c @@ -0,0 +1,531 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/add_int8.h" +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/avx/common_utils.h" +#endif +#include "nnacl_c/int8/fixed_point.h" + +#ifdef ENABLE_ARM +void AddInt8InputRounding(int32x4_t *in1, int32x4_t *in2, int32x4_t *in3, int32x4_t *in4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *in1 = vmulq_s32(*in1, left_vec); + *in2 = vmulq_s32(*in2, left_vec); + *in3 = vmulq_s32(*in3, left_vec); + *in4 = vmulq_s32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. + *in1 = vqrdmulhq_n_s32(*in1, multiplier); + *in2 = vqrdmulhq_n_s32(*in2, multiplier); + *in3 = vqrdmulhq_n_s32(*in3, multiplier); + *in4 = vqrdmulhq_n_s32(*in4, multiplier); + + // Apply right shift + *in1 = vqaddq_s32(*in1, vshrq_n_s32(vandq_s32(*in1, right_vec), 31)); + *in2 = vqaddq_s32(*in2, vshrq_n_s32(vandq_s32(*in2, right_vec), 31)); + *in3 = vqaddq_s32(*in3, vshrq_n_s32(vandq_s32(*in3, right_vec), 31)); + *in4 = vqaddq_s32(*in4, vshrq_n_s32(vandq_s32(*in4, right_vec), 31)); + + *in1 = vrshlq_s32(*in1, right_vec); + *in2 = vrshlq_s32(*in2, right_vec); + *in3 = vrshlq_s32(*in3, right_vec); + *in4 = vrshlq_s32(*in4, right_vec); +} + +void AddInt8OutputRounding(int32x4_t *out1, int32x4_t *out2, int32x4_t *out3, int32x4_t *out4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *out1 = vshlq_s32(*out1, left_vec); + *out2 = vshlq_s32(*out2, left_vec); + *out3 = vshlq_s32(*out3, left_vec); + *out4 = vshlq_s32(*out4, left_vec); + + // Apply the fixed-point part of the multiplier. + *out1 = vqrdmulhq_n_s32(*out1, multiplier); + *out2 = vqrdmulhq_n_s32(*out2, multiplier); + *out3 = vqrdmulhq_n_s32(*out3, multiplier); + *out4 = vqrdmulhq_n_s32(*out4, multiplier); + + // Apply right shift + *out1 = vqaddq_s32(*out1, vshrq_n_s32(vandq_s32(*out1, right_vec), 31)); + *out2 = vqaddq_s32(*out2, vshrq_n_s32(vandq_s32(*out2, right_vec), 31)); + *out3 = vqaddq_s32(*out3, vshrq_n_s32(vandq_s32(*out3, right_vec), 31)); + *out4 = vqaddq_s32(*out4, vshrq_n_s32(vandq_s32(*out4, right_vec), 31)); + + *out1 = vrshlq_s32(*out1, right_vec); + *out2 = vrshlq_s32(*out2, right_vec); + *out3 = vrshlq_s32(*out3, right_vec); + *out4 = vrshlq_s32(*out4, right_vec); +} +#endif + +void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, const AddQuantParameter *params) { + int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); + int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); + int index = 0; +#ifdef ENABLE_ARM + const int8x16_t min_vec = vdupq_n_s8(params->min_); + const int8x16_t max_vec = vdupq_n_s8(params->max_); + + const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_); + const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_); + const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); + + const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift); + const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift); + + const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_args_.right_shift_); + const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_args_.right_shift_); + + const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); + const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + + for (; index <= size - 16; index += 16) { + const int8x16_t in0_src = vld1q_s8(input0 + index); + const 
int8x16_t in1_src = vld1q_s8(input1 + index); + + const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src)); + const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src)); + const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src)); + const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src)); + + const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec); + const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec); + const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec); + const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec); + + int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low)); + int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low)); + int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high)); + int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high)); + int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low)); + int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low)); + int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high)); + int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high)); + + AddInt8InputRounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, in0_right_vec, params->in0_args_.multiplier_); + AddInt8InputRounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, in1_right_vec, params->in1_args_.multiplier_); + + /* calculate output */ + int32x4_t out1 = vaddq_s32(in0_1, in1_1); + int32x4_t out2 = vaddq_s32(in0_2, in1_2); + int32x4_t out3 = vaddq_s32(in0_3, in1_3); + int32x4_t out4 = vaddq_s32(in0_4, in1_4); + + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); + + const int16x4_t out1_s16 = vmovn_s32(out1); + const int16x4_t out2_s16 = vmovn_s32(out2); + const int16x4_t out3_s16 = vmovn_s32(out3); + const int16x4_t out4_s16 = vmovn_s32(out4); + + const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec); + const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); + + const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); + const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); + + vst1q_s8(output + index, int8_out); + } +#endif + for (; index < size; index++) { + const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; + const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; + const int32_t in0 = + MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_); + const int32_t in1 = + MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args) { + int ptr_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_); + int ele_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_); + int index = 0; + +#ifdef ENABLE_ARM + /* const value init */ + const int8x16_t min_vec = vdupq_n_s8(params->min_); + const int8x16_t max_vec = vdupq_n_s8(params->max_); + + const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_); + const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_); + const int16x8_t 
out_zp_vec = vdupq_n_s16(params->out_zp_); + + const int32x4_t ptr_left_vec = vdupq_n_s32(ptr_left_shift); + const int32x4_t ele_left_vec = vdupq_n_s32(ele_left_shift); + + const int32x4_t ptr_right_vec = vdupq_n_s32(-ptr_args->right_shift_); + const int32x4_t ele_right_vec = vdupq_n_s32(-ele_args->right_shift_); + + const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); + const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + + /* deal with const node */ + const int8x16_t ele_src = vdupq_n_s8(element_in); + const int16x8_t ele_s16_low = vmovl_s8(vget_low_s8(ele_src)); + const int16x8_t ele_s16_high = vmovl_s8(vget_high_s8(ele_src)); + const int16x8_t ele_zp_low = vaddq_s16(ele_s16_low, ele_zp_vec); + const int16x8_t ele_zp_high = vaddq_s16(ele_s16_high, ele_zp_vec); + int32x4_t ele1 = vmovl_s16(vget_low_s16(ele_zp_low)); + int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low)); + int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high)); + int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high)); + + AddInt8InputRounding(&ele1, &ele2, &ele3, &ele4, ele_left_vec, ele_right_vec, ele_args->multiplier_); + + for (; index <= size - 16; index += 16) { + const int8x16_t ptr_src = vld1q_s8(ptr_in + index); + + const int16x8_t ptr_s16_low = vmovl_s8(vget_low_s8(ptr_src)); + const int16x8_t ptr_s16_high = vmovl_s8(vget_high_s8(ptr_src)); + + const int16x8_t ptr_zp_low = vaddq_s16(ptr_s16_low, ptr_zp_vec); + const int16x8_t ptr_zp_high = vaddq_s16(ptr_s16_high, ptr_zp_vec); + + int32x4_t ptr1 = vmovl_s16(vget_low_s16(ptr_zp_low)); + int32x4_t ptr2 = vmovl_s16(vget_high_s16(ptr_zp_low)); + int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high)); + int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high)); + + AddInt8InputRounding(&ptr1, &ptr2, &ptr3, &ptr4, ptr_left_vec, ptr_right_vec, ptr_args->multiplier_); + + /* calculate output */ + int32x4_t out1 = vaddq_s32(ptr1, ele1); + int32x4_t out2 = vaddq_s32(ptr2, ele2); + int32x4_t out3 = vaddq_s32(ptr3, ele3); + int32x4_t out4 = vaddq_s32(ptr4, ele4); + + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); + + const int16x4_t out1_s16 = vmovn_s32(out1); + const int16x4_t out2_s16 = vmovn_s32(out2); + const int16x4_t out3_s16 = vmovn_s32(out3); + const int16x4_t out4_s16 = vmovn_s32(out4); + + const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec); + const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); + + const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); + const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); + + vst1q_s8(output + index, int8_out); + } +#endif + for (; index < size; index++) { + const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift; + const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift; + const int32_t ptr = MultiplyByMultiplierAndRightShift(ptr_left, ptr_args->multiplier_, ptr_args->right_shift_); + const int32_t ele = MultiplyByMultiplierAndRightShift(ele_left, ele_args->multiplier_, ele_args->right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(ptr + ele, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +int ElementAddInt8(const int8_t *in0, const int8_t *in1, int8_t *out, int size) { + for (int i = 0; i < size; i++) { + out[i] = in0[i] + in1[i]; + } + return 
NNACL_OK; +} + +int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int8_t *tile_in1, int8_t *out, int size, + ArithmeticParameter *param) { + TileDimensionsInt8(in0, in1, tile_in0, tile_in1, param); + return ElementAddInt8(tile_in0, tile_in1, out, size); +} + +#ifdef ENABLE_AVX +void AddInt8Rounding(__m128i *in1, __m128i *in2, __m128i *in3, __m128i *in4, const __m128i left_vec, + const int32_t right_shift, const __m128i multiplier) { + // Apply left shift + *in1 = _mm_mullo_epi32(*in1, left_vec); + *in2 = _mm_mullo_epi32(*in2, left_vec); + *in3 = _mm_mullo_epi32(*in3, left_vec); + *in4 = _mm_mullo_epi32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. + *in1 = _mm_qrdmulh_epi32(*in1, multiplier); + *in2 = _mm_qrdmulh_epi32(*in2, multiplier); + *in3 = _mm_qrdmulh_epi32(*in3, multiplier); + *in4 = _mm_qrdmulh_epi32(*in4, multiplier); + + // Apply right shift + int32_t in1_remainder_mask = (1ll << (right_shift)) - 1; + int32_t in1_remainder_threshold = in1_remainder_mask >> 1; + const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); + const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); + + const __m128i in1_remainder = + _mm_add_epi32(_mm_and_si128(*in1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in1)); + *in1 = _mm_sub_epi32(_mm_rshr_epi32(*in1, right_shift), _mm_cmpgt_epi32(in1_remainder, vin1_remainder_threshold)); + + const __m128i in2_remainder = + _mm_add_epi32(_mm_and_si128(*in2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in2)); + *in2 = _mm_sub_epi32(_mm_rshr_epi32(*in2, right_shift), _mm_cmpgt_epi32(in2_remainder, vin1_remainder_threshold)); + + const __m128i in3_remainder = + _mm_add_epi32(_mm_and_si128(*in3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in3)); + *in3 = _mm_sub_epi32(_mm_rshr_epi32(*in3, right_shift), _mm_cmpgt_epi32(in3_remainder, vin1_remainder_threshold)); + + const __m128i in4_remainder = + _mm_add_epi32(_mm_and_si128(*in4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in4)); + *in4 = _mm_sub_epi32(_mm_rshr_epi32(*in4, right_shift), _mm_cmpgt_epi32(in4_remainder, vin1_remainder_threshold)); +} + +void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, + const AddQuantParameter *params) { + const int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); + const int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); + const __m128i min_vec = _mm_set1_epi8(params->min_); + const __m128i max_vec = _mm_set1_epi8(params->max_); + const __m128i in0_zp_vec = _mm_set1_epi16(params->in0_args_.zp_); + const __m128i in1_zp_vec = _mm_set1_epi16(params->in1_args_.zp_); + const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_); + const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift); + const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift); + const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_); + const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_); + const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_); + int index = 0; + for (; index <= size - 16; index += 16) { + const __m128i in0_src = _mm_loadu_si128((__m128i *)(input0 + index)); + const __m128i in1_src = _mm_loadu_si128((__m128i *)(input1 + index)); + + const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src); + const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0); + const __m128i 
in0_s16_high = _mm256_extractf128_si256(in0_s16, 1); + const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src); + const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0); + const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1); + + const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec); + const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec); + const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec); + const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec); + + __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low); + __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1); + tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high); + __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); + __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low); + __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1); + tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high); + __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); + + AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); + AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); + + /* calculate output */ + __m128i out1 = _mm_add_epi32(in0_1, in1_1); + __m128i out2 = _mm_add_epi32(in0_2, in1_2); + __m128i out3 = _mm_add_epi32(in0_3, in1_3); + __m128i out4 = _mm_add_epi32(in0_4, in1_4); + + // Apply left shift + out1 = _mm_slli_epi32(out1, params->out_left_shift_); + out2 = _mm_slli_epi32(out2, params->out_left_shift_); + out3 = _mm_slli_epi32(out3, params->out_left_shift_); + out4 = _mm_slli_epi32(out4, params->out_left_shift_); + + // Apply the fixed-point part of the multiplier. 
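+    // _mm_qrdmulh_epi32 and _mm_rshr_epi32 are not native x86 intrinsics; they are nnacl
+    // helpers from intrinsics/avx/common_utils.h that emulate ARM's saturating rounding
+    // doubling high multiply (vqrdmulhq_s32) and an arithmetic right shift.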
+ out1 = _mm_qrdmulh_epi32(out1, out_multiplier); + out2 = _mm_qrdmulh_epi32(out2, out_multiplier); + out3 = _mm_qrdmulh_epi32(out3, out_multiplier); + out4 = _mm_qrdmulh_epi32(out4, out_multiplier); + + // Apply right shift + int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; + int32_t out_remainder_threshold = out_remainder_mask >> 1; + const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); + const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); + const __m128i out1_remainder = + _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); + out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), + _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); + const __m128i out2_remainder = + _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); + out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), + _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); + const __m128i out3_remainder = + _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); + out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), + _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); + const __m128i out4_remainder = + _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); + out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), + _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); + + __m128i out1_s16 = _mm_packs_epi32(out1, out2); + __m128i out2_s16 = _mm_packs_epi32(out3, out4); + + __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); + __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); + __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); + __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); + + _mm_storeu_si128((__m128i *)(output + index), int8_out); + } + for (; index < size; index++) { + const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; + const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; + const int32_t in0 = + MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_); + const int32_t in1 = + MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args) { + // input0: ptr_in + // input1: element_in + // load quant parameters of input0 and input1 + const int in0_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_); + const int in1_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_); + const __m128i min_vec = _mm_set1_epi8(params->min_); + const __m128i max_vec = _mm_set1_epi8(params->max_); + const __m128i in0_zp_vec = _mm_set1_epi16(ptr_args->zp_); + const __m128i in1_zp_vec = _mm_set1_epi16(ele_args->zp_); + const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_); + const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift); + const 
__m128i in1_left_vec = _mm_set1_epi32(in1_left_shift); + const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_); + const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_); + const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_); + + // input1 can be processed once because it is const + const __m128i in1_src = _mm_set1_epi8(element_in); + const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src); + const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0); + const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1); + const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec); + const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec); + __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low); + __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1); + tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high); + __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); + + AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); + + int index = 0; + for (; index <= size - 16; index += 16) { + const __m128i in0_src = _mm_loadu_si128((__m128i *)(ptr_in + index)); + const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src); + const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0); + const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1); + const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec); + const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec); + + __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low); + __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1); + tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high); + __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); + + AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); + + /* calculate output */ + __m128i out1 = _mm_add_epi32(in0_1, in1_1); + __m128i out2 = _mm_add_epi32(in0_2, in1_2); + __m128i out3 = _mm_add_epi32(in0_3, in1_3); + __m128i out4 = _mm_add_epi32(in0_4, in1_4); + + // Apply left shift + out1 = _mm_slli_epi32(out1, params->out_left_shift_); + out2 = _mm_slli_epi32(out2, params->out_left_shift_); + out3 = _mm_slli_epi32(out3, params->out_left_shift_); + out4 = _mm_slli_epi32(out4, params->out_left_shift_); + + // Apply the fixed-point part of the multiplier. 
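+    // The remainder-mask/threshold arithmetic below turns the plain right shift into a
+    // round-to-nearest divide by 2^out_right_shift_, which is presumably intended to match
+    // the vrshlq_s32-based rounding used on the ARM path.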
+ out1 = _mm_qrdmulh_epi32(out1, out_multiplier); + out2 = _mm_qrdmulh_epi32(out2, out_multiplier); + out3 = _mm_qrdmulh_epi32(out3, out_multiplier); + out4 = _mm_qrdmulh_epi32(out4, out_multiplier); + + // Apply right shift + int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; + int32_t out_remainder_threshold = out_remainder_mask >> 1; + const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); + const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); + const __m128i out1_remainder = + _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); + out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), + _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); + const __m128i out2_remainder = + _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); + out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), + _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); + const __m128i out3_remainder = + _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); + out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), + _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); + const __m128i out4_remainder = + _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); + out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), + _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); + + __m128i out1_s16 = _mm_packs_epi32(out1, out2); + __m128i out2_s16 = _mm_packs_epi32(out3, out4); + + __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); + __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); + __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); + __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); + + _mm_storeu_si128((__m128i *)(output + index), int8_out); + } + for (; index < size; index++) { + const int32_t in0_left = (ptr_in[index] + ptr_args->zp_) * in0_left_shift; + const int32_t in1_left = (element_in + ele_args->zp_) * in1_left_shift; + const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, ptr_args->multiplier_, ptr_args->right_shift_); + const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, ele_args->multiplier_, ele_args->right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.h new file mode 100644 index 00000000..0ecd8154 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/add_int8.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ADD_INT8_H_ +#define NNACL_ADD_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/int8/arithmetic_int8.h" + +typedef struct AddQuantQrgs { + int32_t zp_; + int32_t left_shift_; + int32_t right_shift_; + int32_t multiplier_; +} AddQuantQrgs; + +typedef struct AddQuantParameter { + int left_shift_; + int32_t min_; + int32_t max_; + + AddQuantQrgs in0_args_; + AddQuantQrgs in1_args_; + + int32_t out_zp_; + int32_t out_left_shift_; + int32_t out_right_shift_; + int32_t out_multiplier_; +} AddQuantParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, const AddQuantParameter *params); + +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args); + +int ElementAddInt8(const int8_t *in0, const int8_t *in1, int8_t *out, int size); + +int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int8_t *tile_in1, int8_t *out, int size, + ArithmeticParameter *param); + +#ifdef ENABLE_AVX +void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, + const AddQuantParameter *params); + +void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args); +#endif +#ifdef __cplusplus +} +#endif + +#endif // NNACL_ADD_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.c new file mode 100644 index 00000000..0d331880 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.c @@ -0,0 +1,237 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "nnacl_c/int8/arg_min_max_int8.h"
+#include <float.h>
+
+void CalcParameter(const int32_t *shape, int dims_number, int axis, int32_t *pre_axis_count, int32_t *axis_count,
+                   int32_t *after_axis_count) {
+  *pre_axis_count = 1;
+  for (int i = 0; i < axis; ++i) {
+    *pre_axis_count = (*pre_axis_count) * shape[i];
+  }
+
+  *axis_count = shape[axis];
+
+  *after_axis_count = 1;
+  for (int i = axis + 1; i < dims_number; ++i) {
+    *after_axis_count = (*after_axis_count) * shape[i];
+  }
+}
+
+void SetOutputValue(float value, int32_t index, int8_t *output1, int8_t *output2, int offset,
+                    float output_inverse_scale, float output_zp, bool out_value) {
+  if (output2 != NULL) {
+    int32_t *output1_index = (int32_t *)output1;
+    output1_index[offset] = index;
+    output2[offset] = value * output_inverse_scale + output_zp;
+  } else {
+    if (out_value) {
+      output1[offset] = value * output_inverse_scale + output_zp;
+    } else {
+      int32_t *output1_index = (int32_t *)output1;
+      output1_index[offset] = index;
+    }
+  }
+}
+
+void DoArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const ArgMinMaxComputeParam *param,
+                      int pre_axis_count, int axis_count, int after_axis_count, const QuantArg *in_quant_arg,
+                      const QuantArg *out_quant_arg) {
+  bool out_value = param->out_value_;
+  const float output_inverse_scale = 1.f / out_quant_arg->scale_;
+  float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
+  int32_t output_zp = out_quant_arg->zp_;
+  for (int i = 0; i < pre_axis_count; ++i) {
+    int output_offset = i * after_axis_count;
+    int input_offset = output_offset * axis_count;
+    for (int j = 0; j < after_axis_count; ++j) {
+      float value = -FLT_MAX;
+      if (!param->get_max_) {
+        value = FLT_MAX;
+      }
+      int32_t index = 0;
+      for (int k = 0; k < axis_count; ++k) {
+        float value_tmp = input[input_offset + k * after_axis_count + j] * in_quant_arg->scale_ + bias;
+        if (param->get_max_) {
+          if (value_tmp > value) {
+            value = value_tmp;
+            index = k;
+          }
+        } else {
+          if (value_tmp < value) {
+            value = value_tmp;
+            index = k;
+          }
+        }
+      }
+      SetOutputValue(value, index, output1, output2, output_offset + j, output_inverse_scale, output_zp, out_value);
+    }
+  }
+}
+
+void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape,
+                        const ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg,
+                        const QuantArg *out_quant_arg) {
+  int pre_axis_count = 1;
+  int axis_count = 1;
+  int after_axis_count = 1;
+  CalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
+  DoArgMinMaxQuant(input, output1, output2, param, pre_axis_count, axis_count, after_axis_count, in_quant_arg,
+                   out_quant_arg);
+  return;
+}
+
+int ArgCompareAscInt8(const void *a, const void *b) {
+  return ((ArgElement *)a)->data_.f_data_ - ((ArgElement *)b)->data_.f_data_;
+}
+
+int ArgCompareDescInt8(const void *a, const void *b) {
+  return ((ArgElement *)b)->data_.f_data_ - ((ArgElement *)a)->data_.f_data_;
+}
+
+int8_t GetInt8Output(float real_out, float output_inverse_scale, int32_t output_zp) {
+  return real_out * output_inverse_scale + output_zp;
+}
+
+void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape,
+                       ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
+  bool out_value = param->out_value_;
+  const float output_inverse_scale = 1.f / out_quant_arg->scale_;
+  float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
+  int32_t output_zp = out_quant_arg->zp_;
+  for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
+    for (int j = 0; j < in_shape[0]; ++j) {
+      int offset = param->in_strides_[0] * j + i;
+      param->arg_elements_[j].index_ = (uint32_t)j;
+      param->arg_elements_[j].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias;
+    }
+    if (param->get_max_) {
+      qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescInt8);
+    } else {
+      qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscInt8);
+    }
+
+    for (int j = 0; j < param->topk_; ++j) {
+      int out_offset = j * param->out_strides_[0] + i;
+      SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2,
+                     out_offset, output_inverse_scale, output_zp, out_value);
+    }
+  }
+}
+
+void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape,
+                       ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
+  bool out_value = param->out_value_;
+  const float output_inverse_scale = 1.f / out_quant_arg->scale_;
+  float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
+  int32_t output_zp = out_quant_arg->zp_;
+  int in_shape1 = in_shape[1];
+  for (int i = 0; i < in_shape[0]; ++i) {
+    int in_dim0_offset = i * param->in_strides_[0];
+    int out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < param->in_strides_[1]; ++j) {
+      for (int k = 0; k < in_shape1; ++k) {
+        int offset = param->in_strides_[1] * k + in_dim0_offset + j;
+        param->arg_elements_[k].index_ = (size_t)k;
+        param->arg_elements_[k].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias;
+      }
+      if (param->get_max_) {
+        qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescInt8);
+      } else {
+        qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscInt8);
+      }
+
+      for (int k = 0; k < param->topk_; ++k) {
+        int out_offset = out_dim0_offset + j + k * param->out_strides_[1];
+        SetOutputValue(param->arg_elements_[k].data_.f_data_, param->arg_elements_[k].index_, output1, output2,
+                       out_offset, output_inverse_scale, output_zp, out_value);
+      }
+    }
+  }
+}
+
+void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape,
+                       ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
+  bool out_value = param->out_value_;
+  const float output_inverse_scale = 1.f / out_quant_arg->scale_;
+  float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
+  int32_t output_zp = out_quant_arg->zp_;
+  int in_shape1 = in_shape[1];
+  int in_shape2 = in_shape[2];
+  for (int i = 0; i < in_shape[0]; ++i) {
+    int in_dim0_offset = i * param->in_strides_[0];
+    int out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < in_shape1; ++j) {
+      int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
+      int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
+      for (int k = 0; k < param->in_strides_[2]; ++k) {
+        for (int l = 0; l < in_shape2; ++l) {
+          int offset = param->in_strides_[2] * l + k + in_dim1_offset;
+          param->arg_elements_[l].index_ = (uint32_t)l;
+          param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias;
+        }
+        if (param->get_max_) {
+          qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescInt8);
+        } else {
+          qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscInt8);
+        }
+        for (int l = 0; l < param->topk_; ++l) {
+          int out_offset = out_dim1_offset + k + l * param->out_strides_[2];
+          SetOutputValue(param->arg_elements_[l].data_.f_data_, param->arg_elements_[l].index_, output1, output2,
+                         out_offset, output_inverse_scale, output_zp, out_value);
+        }
+      }
+    }
+  }
+}
+
+void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape,
+                       ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) {
+  bool out_value = param->out_value_;
+  const float output_inverse_scale = 1.f / out_quant_arg->scale_;
+  float bias = -in_quant_arg->zp_ * in_quant_arg->scale_;
+  int32_t output_zp = out_quant_arg->zp_;
+  int in_shape1 = in_shape[1];
+  int in_shape2 = in_shape[2];
+  int in_shape3 = in_shape[3];
+  for (int i = 0; i < in_shape[0]; ++i) {
+    int in_dim0_offset = i * param->in_strides_[0];
+    int out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < in_shape1; ++j) {
+      int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
+      int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
+      for (int k = 0; k < in_shape2; ++k) {
+        int in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
+        int out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
+        for (int l = 0; l < in_shape3; ++l) {
+          int offset = l + in_dim2_offset;
+          param->arg_elements_[l].index_ = (uint32_t)l;
+          param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias;
+        }
+        if (param->get_max_) {
+          qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescInt8);
+        } else {
+          qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscInt8);
+        }
+        for (int l = 0; l < param->topk_; ++l) {
+          int out_offset = out_dim2_offset + l;
+          SetOutputValue(param->arg_elements_[l].data_.f_data_, param->arg_elements_[l].index_, output1, output2,
+                         out_offset, output_inverse_scale, output_zp, out_value);
+        }
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.h
new file mode 100644
index 00000000..d4ccfc5a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arg_min_max_int8.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef NNACL_INT8_ARG_MIN_MAX_INT8_H_ +#define NNACL_INT8_ARG_MIN_MAX_INT8_H_ + +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + const ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_ARG_MIN_MAX_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.c new file mode 100644 index 00000000..f08592a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.c @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/int8/arithmetic_int8.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/errorcode.h"
+
+void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int32_t *inShape,
+                          const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple) {
+  int srcDimSize = inShape[dim];
+  if (dim == ndim - 1) {
+    for (int i = 0; i < multiple[dim]; i++) {
+      memcpy(outData, inData, srcDimSize * sizeof(int8_t));
+      outData += srcDimSize;
+    }
+    return;
+  }
+  for (size_t i = 0; i < srcDimSize; i++) {
+    for (size_t j = 0; j < multiple[dim]; j++) {
+      TileOneDimensionInt8(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim,
+                           inShape, inStrides, outStrides, multiple);
+    }
+  }
+}
+
+void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1,
+                        ArithmeticParameter *param) {
+  CalcMultiplesAndStrides(param);
+  TileOneDimensionInt8(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_,
+                       param->multiples0_);
+  TileOneDimensionInt8(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_,
+                       param->multiples1_);
+}
+
+#define ACCURACY_DATA 0.00000001
+
+int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
+                        ArithmeticQuantArg *quant_arg) {
+  float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
+  float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
+
+  for (int index = 0; index < element_size; ++index) {
+    float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
+    float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
+    float minus_inputs = in0_real - in1_real;
+    bool out_real = true;
+    if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
+      out_real = false;
+    }
+    output[index] = (uint8_t)out_real;
+  }
+  return NNACL_OK;
+}
+
+int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
+  float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
+  float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
+  for (int index = 0; index < element_size; ++index) {
+    float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
+    float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
+    float minus_inputs = in0_real - in1_real;
+    bool out_real = false;
+    if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) {
+      out_real = true;
+    }
+    output[index] = (uint8_t)out_real;
+  }
+  return NNACL_OK;
+}
+
+int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) {
+  float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
+  float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_;
+  for (int index = 0; index < element_size; ++index) {
+    float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias;
+    float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias;
+    bool out_real = in0_real < in1_real;
+    output[index] = (uint8_t)out_real;
+  }
+  return NNACL_OK;
+}
+
+int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size,
+                         ArithmeticQuantArg *quant_arg) {
+  float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_;
+  float in1_bias = -quant_arg->in1_args_.zp_ *
quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real <= in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real > in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real >= in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +#undef ACCURACY_DATA diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.h new file mode 100644 index 00000000..8d1d2be0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_INT8_ARITHMETIC_INT8_H_ +#define NNACL_INT8_ARITHMETIC_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple); +void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param); + +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); + +int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); + +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_ARITHMETIC_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.c new file mode 100644 index 00000000..5e3fc901 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.c @@ -0,0 +1,305 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/int8/arithmetic_self_int8.h"
+#include <math.h>
+#include <float.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#include "nnacl_c/int8/common_func_int8.h"
+#endif
+#include "nnacl_c/int8/fixed_point.h"
+
+int Int8ElementFloor(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
+  float in_scale = para.in_args_.scale_;
+  int32_t in_zp = para.in_args_.zp_;
+  float out_scale = para.out_args_.scale_;
+  int32_t out_zp = para.out_args_.zp_;
+  float bias = in_zp * in_scale;
+  for (int i = 0; i < element_size; i++) {
+    int32_t output_tmp = round(floorf(input[i] * in_scale + bias) / out_scale) + out_zp;
+    if (output_tmp > para.output_activation_max_) {
+      output[i] = para.output_activation_max_;
+    } else if (output_tmp < para.output_activation_min_) {
+      output[i] = para.output_activation_min_;
+    } else {
+      output[i] = (int8_t)output_tmp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int Int8ElementRound(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
+  float in_scale = para.in_args_.scale_;
+  int32_t in_zp = para.in_args_.zp_;
+  float out_scale = para.out_args_.scale_;
+  int32_t out_zp = para.out_args_.zp_;
+  float bias = in_zp * in_scale;
+  for (int i = 0; i < element_size; i++) {
+    int32_t output_tmp = round(round(input[i] * in_scale + bias) / out_scale) + out_zp;
+    if (output_tmp > para.output_activation_max_) {
+      output[i] = para.output_activation_max_;
+    } else if (output_tmp < para.output_activation_min_) {
+      output[i] = para.output_activation_min_;
+    } else {
+      output[i] = (int8_t)output_tmp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int Int8ElementCeil(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
+  float in_scale = para.in_args_.scale_;
+  int32_t in_zp = para.in_args_.zp_;
+  float out_scale = para.out_args_.scale_;
+  int32_t out_zp = para.out_args_.zp_;
+  float bias = in_zp * in_scale;
+  for (int i = 0; i < element_size; i++) {
+    int32_t output_tmp = round(ceil(input[i] * in_scale + bias) / out_scale) + out_zp;
+    if (output_tmp > para.output_activation_max_) {
+      output[i] = para.output_activation_max_;
+    } else if (output_tmp < para.output_activation_min_) {
+      output[i] = para.output_activation_min_;
+    } else {
+      output[i] = (int8_t)output_tmp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int Int8ElementAbs(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
+  float in_scale = para.in_args_.scale_;
+  int32_t in_zp = para.in_args_.zp_;
+  float out_scale = para.out_args_.scale_;
+  int32_t out_zp = para.out_args_.zp_;
+  float bias = in_zp * in_scale;
+  for (int i = 0; i < element_size; i++) {
+    int32_t output_tmp = round(fabsf(input[i] * in_scale + bias) / out_scale) + out_zp;
+    if (output_tmp > para.output_activation_max_) {
+      output[i] = para.output_activation_max_;
+    } else if (output_tmp < para.output_activation_min_) {
+      output[i] = para.output_activation_min_;
+    } else {
+      output[i] = (int8_t)output_tmp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int Int8ElementSin(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) {
+  float in_scale = para.in_args_.scale_;
+  int32_t in_zp = para.in_args_.zp_;
+  float out_scale = para.out_args_.scale_;
+  int32_t out_zp = para.out_args_.zp_;
+  float bias = in_zp * in_scale;
+  for (int i = 0; i < element_size; i++) {
+    int32_t output_tmp = round(sinf(input[i] * in_scale + bias) / out_scale) + out_zp;
+    if (output_tmp > para.output_activation_max_) {
+      output[i] = para.output_activation_max_;
+    } else if (output_tmp < para.output_activation_min_) {
+
output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementCos(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(cosf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementLog(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(logf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementSqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + int32_t output_tmp = round(sqrtf(input_f32) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementRsqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 <= 0) { + return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; + } + int32_t output_tmp = round(1.f / (sqrtf(input_f32) * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +#ifdef ENABLE_NEON + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input, int32x4_t left_shift_out_vec, int32x4_t output_multiplier_vec, + ArithSelfQuantArg para) { + int32x4_t input_scale = vmulq_s32(scaled_input, scaled_input); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + para.shift_right_); + raw_sum = vaddq_s32(raw_sum, 
vdupq_n_s32(para.out_args_.zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void SquareInt8NEON(const int8_t *input_data, int8_t *output_data, int64_t element_size, ArithSelfQuantArg para, + int32_t *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)para.shift_left_); + + for (; (*index) <= element_size - 8; (*index) += 8) { + int16x8_t input_val = LoadAndAddOffset(input_data, *index, para.in_args_.zp_); + int32x4_t input_low = vmovl_s16(vget_low_s16(input_val)); + int32x4_t input_high = vmovl_s16(vget_high_s16(input_val)); + + int16x4_t sum_low = ClacSumHalfWord(input_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = ClacSumHalfWord(input_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data, res_u8_n0); + output_data += 8; + } +} +#endif + +int Int8ElementSquare(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + int32_t in_zp = para.in_args_.zp_; + int32_t out_zp = para.out_args_.zp_; + + int index = 0; +#ifdef ENABLE_NEON + SquareInt8NEON(input, output, element_size, para, &index); +#endif + for (; index < element_size; index++) { + const int32_t input_val = input[index] + in_zp; + int32_t output_tmp = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input_val * input_val * (1 << para.shift_left_), para.output_multiplier_), + para.shift_right_); + output_tmp += out_zp; + if (output_tmp > para.output_activation_max_) { + output[index] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[index] = para.output_activation_min_; + } else { + output[index] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementLogicalNot(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(((float)(!(bool)(input[i] * in_scale + bias))) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (output_tmp); + } + } + return NNACL_OK; +} + +int Int8ElementReciprocal(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (fabs(input_f32) <= FLT_EPSILON) { + return NNACL_ERR; + } + int32_t output_tmp = round(1.f / (input_f32 * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.h
new file mode 100644
index 00000000..21e7d75b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/arithmetic_self_int8.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_ARITHMETIC_SELF_INT8_H_
+#define NNACL_INT8_ARITHMETIC_SELF_INT8_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int Int8ElementRound(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementFloor(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementCeil(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementAbs(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementSin(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementCos(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementLog(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementSqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementRsqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementSquare(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementLogicalNot(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+int Int8ElementReciprocal(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_ARITHMETIC_SELF_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.c
new file mode 100644
index 00000000..e963505a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.c
@@ -0,0 +1,110 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/int8/batch_to_space_int8.h" + +void BatchToSpaceNoCropForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int64_t stride_h = block_w * out_n; + int64_t output_offset = 0; + int64_t in_stride_h = in_w * in_c; + int64_t in_stride_n = in_stride_h * in_h; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + int64_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + int64_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int64_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[output_offset++] = output_tmp; + } + } + } + } + } + } +} + +void BatchToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const int32_t *crops, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + int64_t stride_h = block_w * out_n; + int64_t output_offset = 0; + int64_t in_stride_h = in_w * in_c; + int64_t in_stride_n = in_stride_h * in_h; + + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + int64_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + int64_t h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + int64_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int64_t w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + int64_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? 
-128 : output_tmp;
+              output[output_offset++] = output_tmp;
+            }
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.h
new file mode 100644
index 00000000..7fe6bdd2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batch_to_space_int8.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_BATCH_TO_SPACE_INT8_H_
+#define NNACL_INT8_BATCH_TO_SPACE_INT8_H_
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void BatchToSpaceNoCropForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n,
+                                   const int32_t *block, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg);
+void BatchToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n,
+                             const int32_t *block, const int32_t *crops, const QuantArg *in_quant_arg,
+                             const QuantArg *out_quant_arg);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_BATCH_TO_SPACE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.c
new file mode 100644
index 00000000..ea89a205
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.c
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/batchnorm_int8.h"
+#include <math.h>
+#include "nnacl_c/batchnorm_parameter.h"
+
+void BatchNormInt8(int8_t *output_ptr, const int8_t *input_ptr, const float *alpha_ptr, const float *beta_ptr,
+                   int task_id, int unit, int units, int channel) {
+  int unit_st = task_id * unit;
+  int unit_end = MSMIN((task_id + 1) * unit, units);
+  for (int u = unit_st; u < unit_end; u++) {
+    for (int c = 0; c < channel; c++) {
+      int32_t output_tmp = round(input_ptr[u * channel + c] * alpha_ptr[c] + beta_ptr[c]);
+      output_tmp = output_tmp > 127 ? 127 : output_tmp;
+      output_tmp = output_tmp < -128 ?
-128 : output_tmp; + output_ptr[u * channel + c] = (int8_t)output_tmp; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.h new file mode 100644 index 00000000..fc3b1f25 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/batchnorm_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_BATCHNORM_H_ +#define NNACL_INT8_BATCHNORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/batchnorm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormInt8(int8_t *output_ptr, const int8_t *input_ptr, const float *alpha_ptr, const float *beta_ptr, + int task_id, int unit, int units, int channel); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_BATCHNORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.c new file mode 100644 index 00000000..b9424784 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.c @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/common_func_int8.h" +#include "nnacl_c/int8/fixed_point.h" + +void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane, + size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi, + int32_t left_shift, int32_t right_shift, int32_t zp, int size) { + if (size == 0) { + return; + } + for (int r = 0; r < plane; r++) { + for (int c = 0; c < oc; c++) { + int c4div = c / size, c4mod = c % size; + int src_index = c4div * in_plane_stride + r * size + c4mod; + int dst_index = r * out_oc_stride + c; + int32_t value = in[src_index]; + if (bias != NULL) { + value = in[src_index] + bias[c]; + } + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + out[dst_index] = (int8_t)value; + } + } + return; +} + +void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, + int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, + int32_t maxi) { +/* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ +#ifndef ENABLE_ARM64 + PostConvFuncCommInt8(in, out, bias, oc, plane, stride, UP_ROUND(plane, C4NUM) * C4NUM, multiplier, mini, maxi, + left_shift, right_shift, zp, C4NUM); +#else + size_t oc4div = oc / C4NUM * C4NUM; + size_t oc4res = oc % C4NUM; + PostFuncInt8C4Neon64(in, bias, out, oc4div, oc4res, plane, stride * sizeof(int8_t), multiplier, left_shift, + right_shift, zp, mini, maxi); +#endif + return; +} + +#ifdef ENABLE_ARM +int16x8_t LoadAndAddOffset(const int8_t *data, int index, int offset) { + int8x8_t input_s8 = vld1_s8(data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + return vaddq_s16(input_s16, vdupq_n_s16(offset)); +} + +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec) { + int32x4_t shifted_input = vmulq_s32(input, left_shift_result_vec); + shifted_input = vqrdmulhq_s32(shifted_input, input_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(shifted_input, right_shift_vec), 31); + return vrshlq_s32(vqaddq_s32(shifted_input, fixup), right_shift_vec); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.h new file mode 100644 index 00000000..fba3af4c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/common_func_int8.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_COMMON_FUNC_H_
+#define NNACL_INT8_COMMON_FUNC_H_
+
+#include <stdio.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
+                    int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini,
+                    int32_t maxi);
+#ifdef ENABLE_ARM
+void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
+                   int output_channel, int input_step, int8_t input_zp);
+void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp,
+                                    const int32_t *out_multiplier, const int32_t *left_shift,
+                                    const int32_t *right_shift, int32_t acc_min, int32_t acc_max);
+void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
+                          int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
+void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8,
+                               size_t oc4, size_t offset);
+void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
+                      size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
+                      size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, const int8_t *in_zp,
+                      const int32_t *out_zp, const int32_t *out_multiplier, const int32_t *left_shift,
+                      const int32_t *right_shift, const int32_t *acc_min, const int32_t *acc_max);
+void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
+                        size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
+                        size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
+void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums,
+                      int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min,
+                      int32_t acc_max);
+int16x8_t LoadAndAddOffset(const int8_t *data, int index, int offset);
+int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec,
+                          int32x4_t right_shift_vec);
+#endif
+
+#ifdef ENABLE_ARM32
+void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
+                              int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
+                              const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
+                              int32_t acc_min, int32_t acc_max, size_t per_channel);
+#endif
+
+#ifdef ENABLE_ARM64
+void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
+                          size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
+                          int32_t zp, int32_t mini, int32_t maxi);
+void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
+                         int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
+                         int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift,
+                         const int32_t *right_shift, int32_t acc_min, int32_t acc_max, size_t per_channel);
+void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
+                          int input_col_size, int input_row_size, int
channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, size_t per_channel); +void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, + size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, size_t acc_min, size_t acc_max, + size_t per_channel); +void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, + size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + size_t acc_min, size_t acc_max, size_t per_channel); +void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, + size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + size_t acc_min, size_t acc_max, size_t per_channel); +#endif +#ifdef __cplusplus +} +#endif + +#endif /* NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.c new file mode 100644 index 00000000..6b1d8bf0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/int8/concat_int8.h"
+#include <math.h>
+#include <float.h>
+#include <string.h>
+#include "nnacl_c/concat_parameter.h"
+
+void Int8Concat(int8_t **inputs, int8_t *output, const ConcatParameter *para, int axis, int64_t real_dst_count,
+                int task_id, int input_num, int64_t count_unit, int64_t after_axis_size, int **input_shapes,
+                const int32_t *output_shape) {
+  float output_scale = para->quant_arg_.out_args_.scale_;
+  const float output_inverse_scale = 1.f / output_scale;
+  int out_copy_size = output_shape[axis] * after_axis_size;
+  QuantArg *input_quant = para->quant_arg_.in_args_;
+  int output_zp = para->quant_arg_.out_args_.zp_;
+  int8_t max_int8 = para->quant_arg_.output_activation_max_;
+  int8_t min_int8 = para->quant_arg_.output_activation_min_;
+  int64_t start = task_id * count_unit;
+  int64_t end = start + real_dst_count;
+  output += start * out_copy_size;
+
+  for (int k = start; k < end; k++) {
+    for (int i = 0; i < input_num; i++) {
+      const int32_t *input_shape = input_shapes[i];
+      int64_t in_copy_size = input_shape[axis] * after_axis_size;
+      const int8_t *input_ptr = inputs[i] + k * in_copy_size;
+      if (fabs(input_quant[i].scale_ - output_scale) <= FLT_EPSILON && input_quant[i].zp_ == output_zp) {
+        memcpy(output, input_ptr, in_copy_size);
+      } else {
+        float scale = input_quant[i].scale_ * output_inverse_scale;
+        float bias = -input_quant[i].zp_ * scale;
+        for (int j = 0; j < in_copy_size; j++) {
+          int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp;
+          output_tmp = output_tmp > min_int8 ? output_tmp : min_int8;
+          output_tmp = output_tmp < max_int8 ? output_tmp : max_int8;
+          output[j] = (int8_t)output_tmp;
+        }
+      }
+      output += in_copy_size;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.h
new file mode 100644
index 00000000..e80e8831
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/concat_int8.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef NNACL_INT8_CONCAT_INT8_H_ +#define NNACL_INT8_CONCAT_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/concat_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Concat(int8_t **inputs, int8_t *output, const ConcatParameter *para, int axis, int64_t real_dst_count, + int task_id, int input_num, int64_t count_unit, int64_t after_axis_size, int **input_shapes, + const int32_t *output_shape); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONCAT_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.c new file mode 100644 index 00000000..9a51d27f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/conv1x1_int8.h" + +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, + const int32_t *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, + left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, + filter_zp); + return; +} + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, const int32_t *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, + conv_param->output_channel_, is_per_oc, filter_zp); + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.h new file mode 100644 index 00000000..fd8868e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv1x1_int8.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_CONV1X1_INT8_H_
+#define NNACL_INT8_CONV1X1_INT8_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/int8/matmul_int8.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
+                 const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
+                 int32_t *multiplier, ConvParameter *conv_param, const int32_t *filter_zp);
+void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
+                    const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
+                    int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func,
+                    const int32_t *filter_zp);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_CONV1X1_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c
new file mode 100644
index 00000000..6ecb53e1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.c
@@ -0,0 +1,902 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/int8/conv3x3_int8.h" + +void Conv3x3Int8InputUnit(const int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { +#ifdef ENABLE_ARM + int16x8_t zp = vdupq_n_s16(input_zp); + + int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); + int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); + int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); + int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); + + int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); + int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); + int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); + int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); + + int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); + int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); + int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); + int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); + + int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); + int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); + int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); + int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); + + int16x8_t t00 = vsubq_s16(d00, d20); + int16x8_t t01 = vsubq_s16(d01, d21); + int16x8_t t02 = vsubq_s16(d02, d22); + int16x8_t t03 = vsubq_s16(d03, d23); + + int16x8_t t10 = vaddq_s16(d10, d20); + int16x8_t t11 = vaddq_s16(d11, d21); + int16x8_t t12 = vaddq_s16(d12, d22); + int16x8_t t13 = vaddq_s16(d13, d23); + + int16x8_t t20 = vsubq_s16(d20, d10); + int16x8_t t21 = vsubq_s16(d21, d11); + int16x8_t t22 = vsubq_s16(d22, d12); + int16x8_t t23 = vsubq_s16(d23, d13); + + int16x8_t t30 = vsubq_s16(d10, d30); + int16x8_t t31 = vsubq_s16(d11, d31); + int16x8_t t32 = vsubq_s16(d12, d32); + int16x8_t t33 = vsubq_s16(d13, d33); + + int16x8_t m00 = vsubq_s16(t00, t02); + int16x8_t m01 = vaddq_s16(t01, t02); + int16x8_t m02 = vsubq_s16(t02, t01); + int16x8_t m03 = vsubq_s16(t01, t03); + + int16x8_t m10 = vsubq_s16(t10, t12); + int16x8_t m11 = vaddq_s16(t11, t12); + int16x8_t m12 = vsubq_s16(t12, t11); + int16x8_t m13 = vsubq_s16(t11, t13); + + int16x8_t m20 = vsubq_s16(t20, t22); + int16x8_t m21 = vaddq_s16(t21, t22); + int16x8_t m22 = vsubq_s16(t22, t21); + int16x8_t m23 = vsubq_s16(t21, t23); + + int16x8_t m30 = vsubq_s16(t30, t32); + int16x8_t m31 = vaddq_s16(t31, t32); + int16x8_t m32 = vsubq_s16(t32, t31); + int16x8_t m33 = vsubq_s16(t31, t33); + + vst1q_s16(trans_input_data, m00); + vst1q_s16(trans_input_data + step, m01); + vst1q_s16(trans_input_data + 2 * step, m02); + vst1q_s16(trans_input_data + 3 * step, m03); + + vst1q_s16(trans_input_data + 4 * step, m10); + vst1q_s16(trans_input_data + 5 * step, m11); + vst1q_s16(trans_input_data + 6 * step, m12); + vst1q_s16(trans_input_data + 7 * step, m13); + + vst1q_s16(trans_input_data + 8 * step, m20); + vst1q_s16(trans_input_data + 9 * step, m21); + vst1q_s16(trans_input_data + 10 * step, m22); + vst1q_s16(trans_input_data + 11 * step, m23); + + vst1q_s16(trans_input_data + 12 * step, m30); + vst1q_s16(trans_input_data + 13 * step, m31); + vst1q_s16(trans_input_data + 14 * step, m32); + vst1q_s16(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C8NUM; i++) { + const int16_t *local_ptr = tmp_data + i; + int16_t d00 = local_ptr[0] - input_zp; + int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; + int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; + int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; + + int16_t d10 = (local_ptr + 4 * 
C8NUM)[0] - input_zp; + int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; + int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; + int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; + + int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; + int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; + int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; + int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; + + int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; + int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; + int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; + int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; + + int16_t t00 = d00 - d20; + int16_t t01 = d01 - d21; + int16_t t02 = d02 - d22; + int16_t t03 = d03 - d23; + + int16_t t10 = d10 + d20; + int16_t t11 = d11 + d21; + int16_t t12 = d12 + d22; + int16_t t13 = d13 + d23; + + int16_t t20 = d20 - d10; + int16_t t21 = d21 - d11; + int16_t t22 = d22 - d12; + int16_t t23 = d23 - d13; + + int16_t t30 = d10 - d30; + int16_t t31 = d11 - d31; + int16_t t32 = d12 - d32; + int16_t t33 = d13 - d33; + + int16_t m00 = t00 - t02; + int16_t m01 = t01 + t02; + int16_t m02 = t02 - t01; + int16_t m03 = t01 - t03; + + int16_t m10 = t10 - t12; + int16_t m11 = t11 + t12; + int16_t m12 = t12 - t11; + int16_t m13 = t11 - t13; + + int16_t m20 = t20 - t22; + int16_t m21 = t21 + t22; + int16_t m22 = t22 - t21; + int16_t m23 = t21 - t23; + + int16_t m30 = t30 - t32; + int16_t m31 = t31 + t32; + int16_t m32 = t32 - t31; + int16_t m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane) { + const int input_unit = 4; + int dst_step = iC8 * C8NUM * C4NUM; + for (int o = 0; o < output_channel; o++) { + int oc4_block_num = o / C4NUM; + int oc4_block_rem = o % C4NUM; + int src_oc_offset = o * iC8 * C8NUM * kernel_plane; + int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; + for (int i = 0; i < iC8; i++) { + const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; + int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; +#ifdef ENABLE_ARM + int16x8_t g00 = vld1q_s16(src_ic8_ptr); + int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); + int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); + int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); + int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); + int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); + int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); + int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); + int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); + + int16x8_t dst00 = vmulq_n_s16(g00, 2); + int16x8_t dst01 = vmulq_n_s16(g01, 2); + int16x8_t dst02 = vmulq_n_s16(g02, 2); + + int16x8_t dst10 = 
vaddq_s16(vaddq_s16(g00, g10), g20); + int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); + int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); + + int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); + int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); + int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); + + int16x8_t dst30 = vmulq_n_s16(g20, 2); + int16x8_t dst31 = vmulq_n_s16(g21, 2); + int16x8_t dst32 = vmulq_n_s16(g22, 2); + + int16x8_t m00 = vmulq_n_s16(dst00, 2); + int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); + int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); + int16x8_t m03 = vmulq_n_s16(dst02, 2); + + int16x8_t m10 = vmulq_n_s16(dst10, 2); + int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); + int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); + int16x8_t m13 = vmulq_n_s16(dst12, 2); + + int16x8_t m20 = vmulq_n_s16(dst20, 2); + int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); + int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); + int16x8_t m23 = vmulq_n_s16(dst22, 2); + + int16x8_t m30 = vmulq_n_s16(dst30, 2); + int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); + int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); + int16x8_t m33 = vmulq_n_s16(dst32, 2); + + dst_ic8_ptr[0] = m00[0]; + dst_ic8_ptr[4] = m00[1]; + dst_ic8_ptr[8] = m00[2]; + dst_ic8_ptr[12] = m00[3]; + dst_ic8_ptr[16] = m00[4]; + dst_ic8_ptr[20] = m00[5]; + dst_ic8_ptr[24] = m00[6]; + dst_ic8_ptr[28] = m00[7]; + + dst_ic8_ptr[0 + dst_step] = m01[0]; + dst_ic8_ptr[4 + dst_step] = m01[1]; + dst_ic8_ptr[8 + dst_step] = m01[2]; + dst_ic8_ptr[12 + dst_step] = m01[3]; + dst_ic8_ptr[16 + dst_step] = m01[4]; + dst_ic8_ptr[20 + dst_step] = m01[5]; + dst_ic8_ptr[24 + dst_step] = m01[6]; + dst_ic8_ptr[28 + dst_step] = m01[7]; + + dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; + dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; + dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; + dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; + dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; + dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; + dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; + + dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; + dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; + dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; + dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; + dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; + dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; + dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; + + dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; + dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; + dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; + dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; + dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; + dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; + dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; + + dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; + dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; + dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; + dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; + dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; + dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; + dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; + + dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; + dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; + dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; + dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; + dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; + dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; + dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; + + dst_ic8_ptr[0 + 7 * 
dst_step] = m13[0]; + dst_ic8_ptr[4 + 7 * dst_step] = m13[1]; + dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; + dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; + dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; + dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; + dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; + dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; + + dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; + dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; + dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; + dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; + dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; + dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; + dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; + + dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; + dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; + dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; + dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; + dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; + dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; + dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; + + dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; + dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; + dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; + dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; + dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; + dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; + dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; + + dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; + dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; + dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; + dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; + dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; + dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; + dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; + + dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; + dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; + dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; + dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; + dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; + dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; + dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; + + dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; + dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; + dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; + dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; + dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; + dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; + dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; + + dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; + dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; + dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; + dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; + dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; + dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; + dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; + + dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; + dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; + dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; + dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; + dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; + dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; + dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; +#else + for (int j = 0; j < C8NUM; j++) { + const int16_t *local_ptr = src_ic8_ptr + j; + int16_t dst00 = local_ptr[0] * 2; + int16_t dst01 = (local_ptr + 8)[0] * 2; + int16_t dst02 = (local_ptr + 16)[0] * 2; + + int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr 
+ 64)[0]; + + int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst30 = (local_ptr + 48)[0] * 2; + int16_t dst31 = (local_ptr + 56)[0] * 2; + int16_t dst32 = (local_ptr + 64)[0] * 2; + + int16_t m00 = dst00 * 2; + int16_t m01 = dst00 + dst01 + dst02; + int16_t m02 = dst00 - dst01 + dst02; + int16_t m03 = dst02 * 2; + + int16_t m10 = dst10 * 2; + int16_t m11 = dst10 + dst11 + dst12; + int16_t m12 = dst10 - dst11 + dst12; + int16_t m13 = dst12 * 2; + + int16_t m20 = dst20 * 2; + int16_t m21 = dst20 + dst21 + dst22; + int16_t m22 = dst20 - dst21 + dst22; + int16_t m23 = dst22 * 2; + + int16_t m30 = dst30 * 2; + int16_t m31 = dst30 + dst31 + dst32; + int16_t m32 = dst30 - dst31 + dst32; + int16_t m33 = dst32 * 2; + + *(dst_ic8_ptr + j * 4) = m00; + *(dst_ic8_ptr + j * 4 + dst_step) = m01; + *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; + *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; + + *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; + *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; + *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; + *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; + + *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20; + *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; + *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; + *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; + + *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; + *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; + *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; + *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, int oc_start, + const ConvParameter *conv_param) { + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; + +#ifdef ENABLE_ARM + int32x4_t bias_ptr = vld1q_s32(bias_data); + + int32x4_t s00 = vld1q_s32(gemm_out); + int32x4_t s01 = vld1q_s32(gemm_out + 4); + int32x4_t s02 = vld1q_s32(gemm_out + 8); + int32x4_t s03 = vld1q_s32(gemm_out + 12); + + int32x4_t s10 = vld1q_s32(gemm_out + 16); + int32x4_t s11 = vld1q_s32(gemm_out + 20); + int32x4_t s12 = vld1q_s32(gemm_out + 24); + int32x4_t s13 = vld1q_s32(gemm_out + 28); + + int32x4_t s20 = vld1q_s32(gemm_out + 32); + int32x4_t s21 = vld1q_s32(gemm_out + 36); + int32x4_t s22 = vld1q_s32(gemm_out + 40); + int32x4_t s23 = vld1q_s32(gemm_out + 44); + + int32x4_t s30 = vld1q_s32(gemm_out + 48); + int32x4_t s31 = vld1q_s32(gemm_out + 52); + int32x4_t s32 = vld1q_s32(gemm_out + 56); + int32x4_t s33 = vld1q_s32(gemm_out + 60); + + int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); + int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); + int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); + int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); + + int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); + int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); + int32x4_t t12 = 
vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); + int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); + + int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); + int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); + + int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); + int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); + + int32x4_t out_multiplier; + int32x4_t ls; + int32x4_t rs; + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + out_multiplier = vld1q_s32(quant_multiplier + oc_start); + ls = vld1q_s32(left_shift + oc_start); + rs = vld1q_s32(right_shift + oc_start); + } else { + out_multiplier = vdupq_n_s32(quant_multiplier[0]); + ls = vdupq_n_s32(left_shift[0]); + rs = vdupq_n_s32(right_shift[0]); + } + int32x4_t out_zp = vdupq_n_s32(output_zp); + int32x4_t output_min = vdupq_n_s32(out_min); + int32x4_t output_max = vdupq_n_s32(out_max); + + d00 = vqshlq_s32(d00, ls); + d00 = vqrdmulhq_s32(d00, out_multiplier); + int32x4_t carry = vandq_s32(d00, rs); + carry = vshrq_n_s32(carry, 31); + d00 = vqaddq_s32(d00, carry); + d00 = vqrshlq_s32(d00, rs); + d00 = vaddq_s32(d00, out_zp); + d00 = vmaxq_s32(d00, output_min); + d00 = vminq_s32(d00, output_max); + + d01 = vqshlq_s32(d01, ls); + d01 = vqrdmulhq_s32(d01, out_multiplier); + carry = vandq_s32(d01, rs); + carry = vshrq_n_s32(carry, 31); + d01 = vqaddq_s32(d01, carry); + d01 = vqrshlq_s32(d01, rs); + d01 = vaddq_s32(d01, out_zp); + d01 = vmaxq_s32(d01, output_min); + d01 = vminq_s32(d01, output_max); + + d10 = vqshlq_s32(d10, ls); + d10 = vqrdmulhq_s32(d10, out_multiplier); + carry = vandq_s32(d10, rs); + carry = vshrq_n_s32(carry, 31); + d10 = vqaddq_s32(d10, carry); + d10 = vqrshlq_s32(d10, rs); + d10 = vaddq_s32(d10, out_zp); + d10 = vmaxq_s32(d10, output_min); + d10 = vminq_s32(d10, output_max); + + d11 = vqshlq_s32(d11, ls); + d11 = vqrdmulhq_s32(d11, out_multiplier); + carry = vandq_s32(d11, rs); + carry = vshrq_n_s32(carry, 31); + d11 = vqaddq_s32(d11, carry); + d11 = vqrshlq_s32(d11, rs); + d11 = vaddq_s32(d11, out_zp); + d11 = vmaxq_s32(d11, output_min); + d11 = vminq_s32(d11, output_max); + + (output_data)[0] = (int8_t)d00[0]; + (output_data + 1)[0] = (int8_t)d00[1]; + (output_data + 2)[0] = (int8_t)d00[2]; + (output_data + 3)[0] = (int8_t)d00[3]; + + if (w_not_bound) { + *(output_data + 4) = (int8_t)d01[0]; + *(output_data + 5) = (int8_t)d01[1]; + *(output_data + 6) = (int8_t)d01[2]; + *(output_data + 7) = (int8_t)d01[3]; + } + if (h_not_bound) { + *(output_data + output_w * 4) = (int8_t)d10[0]; + *(output_data + output_w * 4 + 1) = (int8_t)d10[1]; + *(output_data + output_w * 4 + 2) = (int8_t)d10[2]; + *(output_data + output_w * 4 + 3) = (int8_t)d10[3]; + if (w_not_bound) { + *(output_data + output_w * 4 + 4) = (int8_t)d11[0]; + *(output_data + output_w * 4 + 5) = (int8_t)d11[1]; + *(output_data + output_w * 4 + 6) = (int8_t)d11[2]; + *(output_data + output_w * 4 + 7) = (int8_t)d11[3]; + } + } +#else + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 
24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + int oc_index = oc_start + i; + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? 
d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } else { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? 
d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } +#endif +} + +void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, const ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + const int oc4 = UP_DIV(output_channel, C4NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const int32_t *src_ptr = gemm_out + src_oc4_offset; + const int32_t *bias_ptr = bias_data + j * C4NUM; + int8_t *dst_ptr = out_data + dst_oc4_offset; + + // output transform + int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, + conv_param); + } + } +} + +void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, const ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_l_; + int pad_h = conv_param->pad_u_; + ConvQuantArg quant_arg = conv_param->conv_quant_arg_; + int input_zp = quant_arg.input_quant_args_[0].zp_; + const int ic8 = UP_DIV(input_channel, C8NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? 
input_unit : (input_height - origin_y);
+
+    int src_plane_offset = C8NUM * (origin_y * input_width + origin_x);
+    int dst_plane_offset = cal_id * C8NUM;
+    for (int ic = 0; ic < ic8; ic++) {
+      // pad the tile with the input zero-point, then copy the valid region from the input
+      for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) {
+        tmp_data[i] = input_zp;
+      }
+      int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width;
+      for (int j = real_y_start; j < real_y_end; j++) {
+        const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start);
+        int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start);
+        memcpy(dst, src, (size_t)(real_x_end - real_x_start) * C8NUM * sizeof(int16_t));
+      }
+      // input transform
+      int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM;
+      size_t dst_step = (size_t)ic8 * C8NUM * TILE_NUM;
+      int16_t *trans_input_ptr = trans_input + dst_ic8_offset;
+      Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp);
+    }
+  }
+}
+
+void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
+  int oc4 = UP_DIV(oc, C4NUM);
+#ifdef ENABLE_ARM
+  IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, (size_t)oc4 * 4 * 16 * sizeof(int32_t));
+#else
+  const int input_unit_square = 16;
+  for (int c = 0; c < oc4; c++) {
+    int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM;
+    int dst_oc_offset = c * input_unit_square * C4NUM;
+    for (int n = 0; n < real_cal_num; n++) {
+      int src_tile_offset = n * C8NUM;
+      int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square;
+      for (int i = 0; i < 4; i++) {
+        int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM;
+        int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM;
+        int dst_h_offset = dst_tile_offset + i * 4 * 4;
+        for (int m = 0; m < 4; m++) {
+          int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8;
+          int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM;
+          int dst_w_offset = dst_h_offset + m * C4NUM;
+
+          int32_t acc[4] = {0};
+          for (int z = 0; z < 4; z++) {
+            int filter_offset = filter_w_offset + z;
+            for (int j = 0; j < ic8; j++) {
+              int filter_c8_offset = filter_offset + j * 4 * 8;
+              int src_c8_offset = src_w_offset + j * 8 * 8;
+
+              for (int k = 0; k < 8; k++) {
+                const int16_t *w_ptr = weight + filter_c8_offset + k * 4;
+                const int16_t *input_ptr = src + src_c8_offset + k;
+                acc[z] += w_ptr[0] * input_ptr[0];
+              }
+            }
+            (dst + dst_w_offset + z)[0] = acc[z];
+          }
+        }
+      }
+    }
+  }
+#endif
+}
+
+// int8 convolution 3x3
+void Conv3x3Int8(const int16_t *input_data, const int16_t *transed_weight, const int32_t *bias_data,
+                 int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer,
+                 int8_t *tmp_out, int task_id, const ConvParameter *conv_param) {
+  int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
+  int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
+  int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
+  int output_count = out_w_block * out_h_block;
+  NNACL_CHECK_ZERO_RETURN(TILE_NUM);
+  int output_tile_count = UP_DIV(output_count, TILE_NUM);
+  int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
+  int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM;
+  const int block_unit_buffer_offset = 16 * C8NUM;
+  int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM;
+
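+  /* Winograd-style F(2x2, 3x3) pipeline: each thread takes TILE_NUM 4x4 input
+   * tiles, transforms them, multiplies them against the pre-transformed
+   * weights (gemm), then inverse-transforms and requantizes the 2x2 output
+   * units into tmp_out. */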
+  for (int batch = 0; batch < conv_param->input_batch_; batch++) {
+    int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
+    int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;
+    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
+      int start_index = thread_id * TILE_NUM;
+      int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
+
+      Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
+                                block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
+                                out_w_block, conv_param);
+
+      Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
+                      transed_weight, conv_param->output_channel_, ic8, real_cal_num);
+
+      Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
+                                 bias_data, start_index, real_cal_num, out_w_block, conv_param);
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.h
new file mode 100644
index 00000000..794c4ea6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv3x3_int8.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_CONV3X3_INT8_H_
+#define NNACL_INT8_CONV3X3_INT8_H_
+
+#include <string.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/int8/common_func_int8.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel,
+                                int kernel_plane);
+
+void Conv3x3Int8(const int16_t *input_data, const int16_t *transed_weight, const int32_t *bias_data,
+                 int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer,
+                 int8_t *tmp_out, int task_id, const ConvParameter *conv_param);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_CONV3X3_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.c
new file mode 100644
index 00000000..7d83429c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.c
@@ -0,0 +1,825 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/conv_depthwise_int8.h"
+#include <string.h>
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/int8/common_func_int8.h"
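+
+/* Depthwise int8 convolution: every channel is filtered by its own
+ * single-channel kernel. Accumulation is done in int32 with the input
+ * zero-point subtracted on the fly; results are then requantized to int8
+ * (per-tensor or per-channel multipliers) with a fixed-point multiply and
+ * rounding shift. */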
+
+/*conv depthwise int8 begin*/
+#ifndef ENABLE_ARM
+void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
+                   int output_channel, int input_step, int8_t input_zp) {
+  for (int i = 0; i < num_pixels; i++) {
+    for (int c = 0; c < output_channel; c++) {
+      const int16_t input = input_ptr[c] - input_zp;
+      *output_ptr++ += input * weight_ptr[c];
+    }
+    input_ptr += input_step;
+  }
+}
+#endif
+
+void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int32_t output_zp,
+                    const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
+                    int32_t acc_min, int32_t acc_max, bool per_channel) {
+  if (per_channel) {
+    // per-channel: each channel uses its own multiplier and shifts
+    for (int w = 0; w < output_w; w++) {
+      int channel4 = 0;
+#ifdef ENABLE_ARM
+      channel4 = channel / 4 * 4;
+      ConvDwInt8PostAlign4PerChannel(dst, buffer, channel4, output_zp, out_multiplier, left_shift, right_shift, acc_min,
+                                     acc_max);
+#endif
+      for (int c = channel4; c < channel; c++) {
+        buffer[c] = RoundingDivideByPOT(
+          SaturatingRoundingDoublingHighMul(buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
+          -right_shift[c]);
+        buffer[c] += output_zp;
+        buffer[c] = MSMAX(buffer[c], acc_min);
+        buffer[c] = MSMIN(buffer[c], acc_max);
+        dst[c] = (buffer[c]);
+      }
+      buffer += channel;
+      dst += channel;
+    }
+  } else {
+    int num_pixels = output_w * channel;
+    int align_num = 0;
+#ifdef ENABLE_ARM
+    align_num = num_pixels / 4 * 4;
+    ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min,
+                         acc_max);
+#endif
+    for (int i = align_num; i < num_pixels; i++) {
+      buffer[i] = RoundingDivideByPOT(
+        SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]),
+        -right_shift[0]);
+      buffer[i] += output_zp;
+      buffer[i] = MSMAX(buffer[i], acc_min);
+      buffer[i] = MSMIN(buffer[i], acc_max);
+      dst[i] = (buffer[i]);
+    }
+  }
+}
+
+void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_data, const int16_t *weight_data,
+                const int32_t *bias_data, const ConvParameter *conv_param, int task_id) {
+  int step_h = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
+  int start_h = step_h * task_id;
+  int end_h = MSMIN(start_h + step_h, conv_param->output_h_);
+
+  bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
+  int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
+  int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_;
+  int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_;
+
+  int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
+  int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
+  int acc_min = conv_param->conv_quant_arg_.out_act_min_[0];
+  int acc_max = conv_param->conv_quant_arg_.out_act_max_[0];
+
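+  /* Process one output row at a time: row_buffer (int32) is seeded with the
+   * bias, one kernel row is accumulated per pass, and ConvDwInt8Post then
+   * requantizes the finished row to int8. */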
+  for (int b = 0; b < conv_param->output_batch_; b++) {
+    const int8_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
+    int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
+    for (int oh = start_h; oh < end_h; oh++) {
+      int8_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
+
+      int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
+      int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
+      int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
+
+      // init acc
+      for (int ow = 0; ow < conv_param->output_w_; ow++) {
+        memcpy(row_buffer + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(int32_t));
+      }
+      for (int kh = start_kh; kh < end_kh; kh++) {
+        int ih = ih_origin + conv_param->dilation_h_ * kh;
+
+        const int8_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
+        const int16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;
+
+        int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
+        for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
+          int out_w_start = MSMAX(
+            0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
+          int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
+                                                        conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
+                                                         conv_param->stride_w_);
+
+          int32_t *acc_w = row_buffer + out_w_start * conv_param->output_channel_;
+          int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
+
+          const int8_t *src_kw = src_kh + iw_origin * conv_param->input_channel_;
+          int num_pixels = out_w_end - out_w_start;
+
+          ConvDwInt8Row(acc_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step, input_zp);
+          weight_kh += conv_param->output_channel_;
+        }
+      }
+      // post func, acc int32 -> dst int8
+      ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_, conv_param->output_channel_, output_zp,
+                     out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel);
+    }
+  }
+}
+/*conv depthwise int8 end*/
+
+/*conv depthwise 3x3 int8 begin*/
+void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvParameter *conv_param, int block_input_h,
+                             int block_input_w) {
+  for (int h = 0; h < block_input_h; h++) {
+    const int8_t *src = input;
+    for (int w = 0; w < block_input_w; w++) {
+      memcpy(buffer, src, 64);  // one 64-channel (64-byte) block per pixel
+      src += conv_param->input_channel_;
+      buffer += 64;
+    }
+    input += conv_param->input_w_ * conv_param->input_channel_;
+  }
+}
+
+void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size,
+                         int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
+                         const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift,
+                         int32_t acc_min, int32_t acc_max, int stride, bool per_channel) {
+  for (int w = 0; w < output_w; w++) {
+    int tmp_buffer[C8NUM];
+    for (int i = 0; i < C8NUM; i++) {
+      tmp_buffer[i] = 0;
+    }
+    int8_t *output_tmp = output;
+    const int8_t *src_kh = buffer;
+    const int16_t *weight_kh = weight;
+    for (int kh = 0; kh < 3; kh++) {
+      const int8_t *src_kw = src_kh;
+      const int16_t *weight_kw = weight_kh;
+      for (int kw = 0; kw < 3; kw++) {
+        for (int c = 0; c < 8; c++) {
+          tmp_buffer[c] += (src_kw[c] - in_zp) * weight_kw[c];
+        }
+        src_kw += col_size;
+        weight_kw +=
channel; + } + src_kh += row_size; + weight_kh += 3 * channel; + } + if (per_channel) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + *output_tmp++ = (tmp_buffer[c]); + } + } else { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + *output_tmp++ = (tmp_buffer[c]); + } + } + output += channel; + buffer += col_size * stride; + } +} + +void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int start_c, + int end_c, int col_size, int row_size, int channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, int stride, bool per_channel) { + for (; start_c <= end_c - 8; start_c += 8) { +#ifdef ENABLE_ARM64 + if (stride == 1) { + ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, per_channel); + } else { + ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, per_channel); + } + +#else + ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, stride, per_channel); +#endif + output += 8; + buffer += 8; + weight += 8; + bias += 8; + if (per_channel) { + out_multiplier += 8; + left_shift += 8; + right_shift += 8; + } + } +} + +void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const int16_t *weight, const int32_t *bias, + const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w, + int block_input_h, int block_input_w) { + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int in_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + + const int ih_offset = 64 * block_input_w; + int w = start_w; + if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) { + for (; w <= end_w - block_output_w; w += block_output_w) { + int8_t *output_ptr = output; + const int8_t *input_ptr = input; + const int16_t *weight_ptr = weight; + const int32_t *bias_ptr = bias; + int32_t *out_multiplier_ptr = out_multiplier; + int32_t *left_shift_ptr = left_shift; + int32_t *right_shift_ptr = right_shift; + int c = 0; + for (; c <= 
conv_param->output_channel_ - 64; c += 64) { + ConvDw3x3Int8InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w); + ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_, + block_output_h, block_output_w, in_zp, out_zp, out_multiplier_ptr, left_shift_ptr, + right_shift_ptr, acc_min, acc_max, conv_param->stride_h_, filter_per_channel); + output_ptr += 64; + input_ptr += 64; + weight_ptr += 64; + bias_ptr += 64; + if (filter_per_channel) { + out_multiplier_ptr += 64; + left_shift_ptr += 64; + right_shift_ptr += 64; + } + } + // left channel + ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_, + conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_, + conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier_ptr, + left_shift_ptr, right_shift_ptr, acc_min, acc_max, conv_param->stride_h_, filter_per_channel); + output += block_output_w * conv_param->input_channel_; + input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_; + } + } + // left width + int left_width = end_w - w; + if (left_width > 0) { + ConvDw3x3Int8Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_, + conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h, + left_width, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + conv_param->stride_h_, filter_per_channel); + } +} + +void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + int output_h = sliding->bottom_ - sliding->top_; + int step_oh = UP_DIV(output_h, conv_param->thread_num_); + int start_oh = step_oh * task_id + sliding->top_; + int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_); + int start_ow = sliding->left_; + int end_ow = sliding->right_; + + const int block_output_h = 1; + int block_output_w = conv_param->stride_w_ == 1 ? 
30 : 14; + const int block_input_h = 3; + int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3; + + for (int b = 0; b < conv_param->output_batch_; b++) { + int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_; + const int8_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ + + start_ih * conv_param->input_w_ * conv_param->input_channel_ + + start_iw * conv_param->input_channel_; + int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ + + start_oh * conv_param->output_w_ * conv_param->output_channel_ + + start_ow * conv_param->output_channel_; + + for (int oh = start_oh; oh < end_oh; oh++) { + ConvDw3x3Int8Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h, + block_output_w, block_input_h, block_input_w); + src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_; + dst += conv_param->output_w_ * conv_param->output_channel_; + } + } +} + +#ifndef ENABLE_ARM32 +void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + const int32_t acc_min, const int32_t acc_max, bool per_channel) { + for (int c = 0; c < channel; c += 8) { + int tmp_buffer[8]; + for (int i = 0; i < 8; i++) { + tmp_buffer[i] = 0; + } + const int8_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += (src_kw[c + i] - in_zp) * weight_kw[c + i]; + } + src_kw += in_kw_step; + weight_kw += channel; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += 3 * channel; + } // kernel_h loop + if (per_channel) { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += bias[c + i]; + tmp_buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[i] * (1 << (unsigned int)left_shift[i]), out_multiplier[i]), + -right_shift[i]); + tmp_buffer[i] += out_zp; + tmp_buffer[i] = MSMAX(tmp_buffer[i], acc_min); + tmp_buffer[i] = MSMIN(tmp_buffer[i], acc_max); + dst[i] = (tmp_buffer[i]); + } + left_shift += 8; + right_shift += 8; + out_multiplier += 8; + } else { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += bias[c + i]; + tmp_buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + tmp_buffer[i] += out_zp; + tmp_buffer[i] = MSMAX(tmp_buffer[i], acc_min); + tmp_buffer[i] = MSMIN(tmp_buffer[i], acc_max); + dst[i] = (tmp_buffer[i]); + } + } + dst += 8; + } +} +#endif + +#ifndef ENABLE_ARM64 +void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} + +void ConvDw3x3Int8Vertical(int8_t 
*dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} + +void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} +#endif + +void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int in_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + int input_row_size = conv_param->input_w_ * conv_param->input_channel_; + int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_; + int output_row_size = conv_param->output_w_ * conv_param->output_channel_; + int in_kh_step = sliding->in_kh_step_; + int in_kw_step = sliding->in_kw_step_; + + // top + for (int b = 0; b < conv_param->output_batch_; b++) { + const int8_t *input_batch = + input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + int8_t *output_batch = + output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + + const int8_t *input = input_batch; + const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_; + int8_t *output = output_batch; + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; + weight = weight_data + weight_row_size; + output += conv_param->output_channel_; + for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { + ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ * conv_param->input_channel_; + output += conv_param->output_channel_; + } + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + + // left + input = input_batch + (conv_param->stride_h_ - 1) * 
input_row_size; + weight = weight_data + conv_param->input_channel_; + output = output_batch + output_row_size; + for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { + ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, + in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + filter_per_channel); + input += conv_param->stride_h_ * input_row_size; + output += output_row_size; + } + + // right + input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ + + (conv_param->stride_h_ - 1) * input_row_size; + weight = weight_data; + output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; + for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { + ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, + in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + filter_per_channel); + input += conv_param->stride_h_ * input_row_size; + output += output_row_size; + } + + // bottom + input = input_batch + (conv_param->input_h_ - 2) * input_row_size; + weight = weight_data + conv_param->input_channel_; + output = output_batch + (conv_param->output_h_ - 1) * output_row_size; + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; + weight = weight_data; + output += conv_param->output_channel_; + for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { + ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ * conv_param->input_channel_; + output += conv_param->output_channel_; + } + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + } +} +/*conv depthwise 3x3 int8 end*/ + +/*conv depthwise sliding window perchannel int8 begin*/ +void ConvDwInt8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w, const int8_t *input_zp, + const int32_t *out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *acc_min, const int32_t *acc_max) { + int tmp_buffer[C8NUM]; + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + const int8_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = 
MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); + dst[c] = (tmp_buffer[c]); + } +} + +void ConvDwInt8Border(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top, int bottom, + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + const int8_t *in_zp, const int32_t *out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *acc_min, + const int32_t *acc_max) { + int8_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const int8_t *src_h = src + ih * sliding->in_h_step_; + + int8_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const int8_t *src_w = src_h + iw * sliding->block_channel_; + + const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + + ConvDwInt8BorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM +void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, const int8_t *in_zp, const int32_t *out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + const int32_t *acc_min, const int32_t *acc_max) { + int tmp_buffer[C8NUM]; + int8_t *dst_h = dst; + const int8_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int8_t *dst_w = dst_h; + const int8_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const int8_t *src_kh = src_w; + const int16_t *weight_kh = weight; + + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + // add bias relu + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); + dst_w[c] = (tmp_buffer[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // 
dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +void ConvDwInt8SW(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const int8_t *input_zp, const int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + const int8_t *src = input_data; + int8_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const int8_t *src_data = src + oc * C8NUM; + int8_t *dst_data = dst + oc * C8NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C8NUM; + + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM; + int32_t *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM; + int32_t *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM; + const int8_t *in_zp = input_zp + oc * C8NUM; + const int32_t *out_zp = output_zp + oc * C8NUM; + + ConvDwInt8Border(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + } + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc8 +} +/*conv depthwise sliding window perchannel int8 end*/ + +/*deconv depthwise int8 begin*/ +void DeconvDwInt8BorderPixel(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w) { + int32_t *dst_kh = dst; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t 
*weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop +} + +void DeconvDwInt8Border(int32_t *dst, const int16_t *src, const int16_t *weight, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + const int16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int32_t *dst_h = dst + oh * sliding->in_h_step_; + + const int16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + int32_t *dst_w = dst_h + ow * C4NUM; + + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + int32_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDwInt8BorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM +void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, + int in_kw_step) { + int32_t *dst_h = dst; + const int16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int32_t *dst_w = dst_h; + const int16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + int32_t *dst_kh = dst_w; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +#ifndef ENABLE_ARM +void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, + int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, + int32_t acc_max) { + int8_t *dst_k = dst; + int32_t *buffer_k = output_buffer; + for (int k = 0; k < pixel_nums; k++) { + for (int c = 0; c < C4NUM; c++) { + buffer_k[c] += bias[c]; + buffer_k[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer_k[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + buffer_k[c] += out_zp; + buffer_k[c] = MSMAX(buffer_k[c], acc_min); + buffer_k[c] = MSMIN(buffer_k[c], acc_max); + dst_k[c] = (buffer_k[c]); + } + dst_k += block_channel; + buffer_k += C4NUM; + } +} +#endif 
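For reference, every int8 post-processing loop in this file (ConvDw3x3Int8BorderPixel, ConvDwInt8BorderPixel, ConvDwInt8Center, DeconvDwInt8Post) requantizes its 32-bit accumulators through the same gemmlowp-style fixed-point pipeline: add the bias, pre-scale by 2^left_shift, apply a saturating rounding doubling high multiply against the quantized output multiplier, rounding-divide by 2^right_shift, add the output zero point, and clamp to the fused activation range. A minimal scalar sketch of one output value, assuming only the RoundingDivideByPOT and SaturatingRoundingDoublingHighMul helpers already used above; the RequantizeOne wrapper name is illustrative and not part of the patch:

static inline int8_t RequantizeOne(int32_t acc, int32_t bias, int32_t multiplier, int32_t left_shift,
                                   int32_t right_shift, int32_t out_zp, int32_t act_min, int32_t act_max) {
  acc += bias;  // fold in the per-channel (or per-layer) bias
  // Doubling high-mul by the quantized multiplier after pre-scaling, then a
  // rounding right shift; the -right_shift negation mirrors the kernels above,
  // which store the right shift as a non-positive exponent.
  acc = RoundingDivideByPOT(
    SaturatingRoundingDoublingHighMul(acc * (1 << (unsigned int)left_shift), multiplier), -right_shift);
  acc += out_zp;              // re-center on the output zero point
  acc = MSMAX(acc, act_min);  // clamp to the fused activation window
  acc = MSMIN(acc, act_max);
  return (int8_t)acc;
}

DeconvDwInt8Post above applies this sequence once per C4 block with the per-layer (index 0) quantization arguments; the depthwise kernels earlier in the file switch to per-channel multipliers and shifts when FILTER_PER_CHANNEL is set.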
+ +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const int16_t *src = input_data; + int8_t *dst = output_data; + int buffer_size = conv_param->output_h_ * conv_param->output_w_ * C4NUM; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + memset(output_buffer, 0, buffer_size * sizeof(int32_t)); + const int16_t *src_data = src + oc * C4NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C4NUM; + int8_t *dst_data = dst + oc * C4NUM; + DeconvDwInt8Border(output_buffer, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const int16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int16_t), + sliding->block_channel_ * sizeof(int16_t), sliding->in_sh_step_ * sizeof(int32_t), + sliding->in_sw_step_ * sizeof(int32_t), sliding->in_kh_step_ * sizeof(int32_t), + sliding->in_kw_step_ * sizeof(int32_t)); +#else + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDwInt8Post(dst_data, output_buffer, bias, sliding->block_channel_, + conv_param->output_h_ * conv_param->output_w_, conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + } // output C4 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nhwc4 +} +/*deconv depthwise int8 end*/ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.h new file mode 100644 index 00000000..60150821 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_depthwise_int8.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_CONV_DEPTHWISE_H_ +#define NNACL_INT8_CONV_DEPTHWISE_H_ + +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, int task_id); + +void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding); + +void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +void ConvDwInt8SW(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const int8_t *input_zp, const int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id); + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONV_DEPTHWISE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c new file mode 100644 index 00000000..9904230c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.c @@ -0,0 +1,913 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/conv_int8.h" + +#ifdef ENABLE_ARM32 +void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); + +#ifdef ENABLE_ARM32 + size_t oc_div2 = output_channel / C2NUM * C2NUM; + size_t oc_res2 = output_channel - oc_div2; + size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride); +#else + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci2div = ci / C2NUM, ci2mod = ci % C2NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} +#endif + +void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); +#ifdef ENABLE_ARM64 + size_t oc_div4 = output_channel / C4NUM * C4NUM; + size_t oc_res4 = output_channel - oc_div4; + size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride); +#else + + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} + +void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, const int32_t *filter_zp, size_t inputsum_stride) { + int ic4 = UP_ROUND(input_channel, C4NUM); + int oc8 = UP_ROUND(output_channel, C8NUM); + int hw8 = UP_ROUND(plane_size, C8NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t oc_8div = output_channel / C8NUM * C8NUM; + size_t oc_8res = output_channel - oc_8div; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + int32_t *input_sum_r = input_sum; + + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_oc = input_sum_r; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + + "mov x10, 
%[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "dup v0.4s, v16.s[0] \n" + "dup v1.4s, v16.s[1] \n" + "dup v2.4s, v16.s[2] \n" + "dup v3.4s, v16.s[3] \n" + "dup v4.4s, v17.s[0] \n" + "dup v5.4s, v17.s[1] \n" + "dup v6.4s, v17.s[2] \n" + "dup v7.4s, v17.s[3] \n" + "mov x4, #0 \n" + "mov x10, 
%[filter_zp] \n" + "mov x11, %[input_sum_oc] \n" + + "7: \n" + "cmp x4, %[oc_8div] \n" + "beq 8f \n" + "add x4, x4, #8\n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.4s}, [x10], #16\n" + + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "add x11, x11, %[input_sum_stride] \n" + "b 7b \n" + + "8: \n" + "cmp %[oc_8res], #0\n" + "beq 17f \n" + + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "cmp %[oc_8res], #1\n" + "beq 9f \n" + "cmp %[oc_8res], #2\n" + "beq 10f \n" + "cmp %[oc_8res], #3\n" + "beq 11f \n" + "cmp %[oc_8res], #4\n" + "beq 12f \n" + "cmp %[oc_8res], #5\n" + "beq 13f \n" + "cmp %[oc_8res], #6\n" + "beq 14f \n" + "cmp %[oc_8res], #7\n" + "beq 15f \n" + + "9: \n" + "ld1 {v16.s}[0], [x10] \n" + "b 16f \n" + + "10: \n" + "ld1 {v16.d}[0], [x10] \n" + "b 16f \n" + + "11: \n" + "ld1 {v16.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v16.s}[2], [x10] \n" + "b 16f \n" + + "12: \n" + "ld1 {v16.4s}, [x10] \n" + "b 16f \n" + + "13: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.s}[0], [x10] \n" + "b 16f \n" + + "14: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "b 16f \n" + + "15: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v17.s}[2], [x10] \n" + "b 16f \n" + + "16: \n" + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "17: \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [filter_zp] "r"(filter_zp), [input_sum_oc] "r"(input_sum_oc), + [input_sum_stride] "r"(input_sum_stride), [src_stride] "r"(src_stride), [ic_4div] "r"(ic_4div), + [ic_4res] 
"r"(ic_4res), [oc_8div] "r"(oc_8div), [oc_8res] "r"(oc_8res) + : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0]; + input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1]; + input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2]; + input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3]; + input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4]; + input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5]; + input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6]; + input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci]; + } + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = 0; + } + } + } /* oc8 res done */ +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + input_sum_r += C8NUM * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t *input_sum_oc = input_sum_r; + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int curoi = 0; curoi < C8NUM; curoi++) { + input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci]; + } + for (int oci = oc_8res; oci < 
C8NUM; oci += 1) { + input_sum_oc[oci] = 0; + } + } /* oc8 res done */ + + src_r += input_channel; + pack_r += C4NUM; + input_sum_r += C8NUM; + } + + for (int hwi = plane_size; hwi < hw8; hwi++) { + for (int oc = 0; oc < oc8; oc++) { + int oc8div = oc / C8NUM, oc8res = oc % C8NUM; + input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0; + } + } + } + return; +} + +void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t plane_size, const ConvParameter *conv_param) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v20.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr 
\n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "mul v16.4s, v16.4s, v20.4s \n" + "mul v17.4s, v17.4s, v20.4s \n" + + "st1 {v16.4s}, [x14], #16 \n" + "st1 {v17.4s}, [x14], #16 \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [input_sum_r] "r"(input_sum_r), [src_stride] "r"(src_stride), + [ic_4div] "r"(ic_4div), [ic_4res] "r"(ic_4res), [filter_zp] "r"(filter_zp) + : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", + "v20"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int i = 0; i < C8NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + +void PackInputSum16x4Int8(const int8_t 
*input, int32_t *input_sum, const int32_t *filter_zp, + const ConvParameter *conv_param) { + size_t hw = conv_param->output_h_ * conv_param->output_w_; + size_t hw4 = UP_ROUND(hw, C4NUM); + size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM); + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); + } else { +#ifdef ENABLE_ARM32 + PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#endif + } + return; +} + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, + int block_index, const int32_t *filter_zp, int32_t *input_sum, + const ConvParameter *conv_param, bool per_channel, bool is_optimize) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int kernel_plane = kernel_h * kernel_w; + NNACL_CHECK_ZERO_RETURN(out_w); + NNACL_CHECK_ZERO_RETURN(dilation_h); + NNACL_CHECK_ZERO_RETURN(dilation_w); + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_stride = input_h * in_w * in_channel + input_w * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (kw_e <= kw_s || kh_e <= kh_s) { + continue; + } + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel); + } + } // kernel_h loop + } + } // tile num loop + int deep = kernel_plane * in_channel; + if (is_optimize) { + if (per_channel) { + Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num, + filter_zp, C8NUM * C8NUM); + } else { + Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param); + } + } else { + RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep); + if (per_channel) { +#ifdef ENABLE_ARM32 + 
PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_); +#endif + } else { + size_t hw4 = UP_ROUND(real_cal_num, C4NUM); + size_t ic16 = UP_ROUND(deep, C16NUM); + PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, + ic16); + } + } +} + +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, + const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, + ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) { + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int tile_n = conv_param->tile_num_; + int output_count = conv_param->output_h_ * conv_param->output_w_; + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_tile_count = UP_DIV(output_count, tile_n); + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + int unit_size; + int input_sum_offset; + int up_round_oc; +#ifdef ENABLE_ARM32 + up_round_oc = UP_ROUND(out_channel, C2NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C16NUM); +#else + if (is_optimize) { + up_round_oc = UP_ROUND(out_channel, C8NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C4NUM); + } else { + up_round_oc = UP_ROUND(out_channel, C4NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C16NUM); + } +#endif + bool per_channel = false; + if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { + input_sum_offset = tile_n * up_round_oc; + per_channel = true; + } else { + input_sum_offset = tile_n; + per_channel = false; + } + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? 
(output_count - start_index) : tile_n; + int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset; + int8_t *gemm_input = packed_input + task_id * unit_size * tile_n; + int8_t *matmul = matmul_input + task_id * kernel_plane * in_channel * tile_n; + memset(matmul, conv_param->conv_quant_arg_.input_quant_args_[0].zp_, kernel_plane * in_channel * tile_n); + Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, matmul, real_cal_num, start_index, filter_zp, + tmp_input_sum, conv_param, per_channel, is_optimize); + + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + int8_t *gemm_output = output_data + out_offset; +#ifdef ENABLE_ARM32 + MatmulInt8Neon32( + gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, tmp_input_sum, bias_data, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.quant_multiplier_, + conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, out_channel, per_channel); +#elif ENABLE_ARM64 + if (is_optimize) { + matmul_func(gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel, + tmp_input_sum, bias_data, conv_param->conv_quant_arg_.left_shift_, + conv_param->conv_quant_arg_.right_shift_, conv_param->conv_quant_arg_.quant_multiplier_, + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0], per_channel); + } else { + MatmulInt8Neon64(gemm_input, packed_weight, gemm_output, UP_ROUND(real_cal_num, C4NUM), + UP_ROUND(out_channel, C4NUM), unit_size, tmp_input_sum, bias_data, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.left_shift_, + conv_param->conv_quant_arg_.right_shift_, real_cal_num, out_channel, out_channel, per_channel); + } +#else + MatMulInt8_8x8_r( + gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel, tmp_input_sum, + bias_data, conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, + conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel); +#endif + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.h new file mode 100644 index 00000000..18318489 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/conv_int8.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_INT8_CONV_INT8_H_
+#define NNACL_INT8_CONV_INT8_H_
+
+#include <string.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/int8/common_func_int8.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+// int8 conv common
+void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
+              const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
+              ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_CONV_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.c
new file mode 100644
index 00000000..56205628
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.c
@@ -0,0 +1,236 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/crop_parameter.h"
+#include <math.h>
+#include <float.h>
+#include <string.h>
+#include "nnacl_c/int8/crop_int8.h"
+
+void Int8Crop1D(const int8_t *input, int8_t *output, int *output_shape, int64_t *in_offset, int task_id,
+                int thread_count, const CropQuantArg *quant) {
+  const int out_batch = output_shape[0];
+  int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_batch, thread_count) : out_batch;
+  if (task_id_stride <= 0) {
+    return;
+  }
+
+  float in_scale = quant->in_args_.scale_;
+  int32_t in_zp = quant->in_args_.zp_;
+  float out_scale = quant->out_args_.scale_;
+  int32_t out_zp = quant->out_args_.zp_;
+  float scale = in_scale / out_scale;
+  float bias = -in_zp * scale;
+
+  int n = task_id * task_id_stride;
+  if (n >= out_batch) {
+    return;
+  }
+  const int8_t *in_ptr = input + n + in_offset[0];
+  int8_t *out_ptr = output + n;
+  int64_t out_dist_stride = MSMIN(out_batch - task_id * task_id_stride, task_id_stride);
+  if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) {
+    memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride);
+  } else {
+    for (int i = 0; i < out_dist_stride; i++) {
+      int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp;
+      if (output_tmp > quant->output_activation_max_) {
+        out_ptr[i] = quant->output_activation_max_;
+      } else if (output_tmp < quant->output_activation_min_) {
+        out_ptr[i] = quant->output_activation_min_;
+      } else {
+        out_ptr[i] = (int8_t)output_tmp;
+      }
+    }
+  }
+  return;
+}
+
+void Int8Crop2D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset,
+                int task_id, int thread_count, const CropQuantArg *quant) {
+  const int in_height = input_shape[1];
+  const int out_batch = output_shape[0];
+  const int out_height = output_shape[1];
+  int64_t task_id_stride = thread_count > 1 ?
UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + int h = task_id * task_id_stride; + if (h >= out_height) { + return; + } + const int8_t *in_ptr = input + (n + in_offset[0]) * in_height + h + in_offset[1]; + int8_t *out_ptr = output + n * out_height + h; + int64_t out_dist_stride = MSMIN(out_height - task_id * task_id_stride, task_id_stride); + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride); + } else { + for (int i = 0; i < out_dist_stride; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + return; +} + +void Int8Crop3D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int in_width = input_shape[2]; + + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + const int out_width = output_shape[2]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_h = in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_h = out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + const int8_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + in_offset[2]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_width); + } else { + for (int i = 0; i < out_width; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + } + return; +} + +void Int8Crop4D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int in_width = input_shape[2]; + const int in_channel = input_shape[3]; + + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + const int out_width = output_shape[2]; + const int out_channel = output_shape[3]; + + int64_t task_id_stride = thread_count > 1 ? 
UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_w = in_channel; + const int in_stride_h = in_channel * in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_w = out_channel; + const int out_stride_h = out_channel * out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + for (int w = 0; w < out_width; w++) { + const int8_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + + (w + in_offset[2]) * in_stride_w + in_offset[3]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h + w * out_stride_w; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_channel); + } else { + for (int i = 0; i < out_channel; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + } + } + return; +} + +void Int8Crop(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_count, const CropQuantArg *quant) { + switch (input_dim) { + case 1: + Int8Crop1D(input, output, output_shape, in_offset, task_id, thread_count, quant); + break; + case 2: + Int8Crop2D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + case 3: + Int8Crop3D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + case 4: + Int8Crop4D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.h new file mode 100644 index 00000000..b4cc4f26 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/crop_int8.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_INT8_CROP_INT8_H_ +#define NNACL_INT8_CROP_INT8_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Crop(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_count, const CropQuantArg *quant); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CROP_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.c new file mode 100644 index 00000000..cc2994e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.c @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/deconv_int8.h" +#include "nnacl_c/int8/matmul_int8.h" +#include "nnacl_c/int8/common_func_int8.h" +int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + const ConvParameter *conv_param) { + /* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */ + int input_plane = conv_param->input_w_ * conv_param->input_h_; + int kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int in_plane4 = UP_ROUND(input_plane, C4NUM); + + int src_iw_stride = C4NUM; + int src_ih_stride = conv_param->input_w_ * C4NUM; + int src_kw_stride = in_plane4 * C4NUM; + int src_kh_stride = in_plane4 * conv_param->kernel_w_ * C4NUM; + int dst_oh_stride = conv_param->output_w_ * C4NUM; + int dst_ow_stride = C4NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C4NUM; + int dst_kw_stride = conv_param->dilation_w_ * C4NUM; + + for (int c = 0; c < oc4; c++) { + int32_t *dst_ptr = tmp + c * output_plane * C4NUM; + const int32_t *src_ptr = src + c * in_plane4 * kernel_plane * C4NUM; + memset(dst_ptr, 0, (size_t)output_plane * C4NUM * sizeof(int32_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + int32_t *tmp_dst = dst_ptr + dst_index; + const int32_t *tmp_src = 
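+          /* Both branches below accumulate one C4NUM block of int32 partial sums:
+           * tmp_dst[i] += tmp_src[i] for i in [0, 4). A minimal intrinsics sketch of
+           * the same step (an illustration assuming <arm_neon.h>, not the code used
+           * here) would be:
+           *   vst1q_s32(tmp_dst, vaddq_s32(vld1q_s32(tmp_dst), vld1q_s32(tmp_src)));
+           */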
src_ptr + src_index; +#ifndef ENABLE_ARM64 + for (int i = 0; i < C4NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#else + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s}, [x0] \n" + "ld1 {v1.4s}, [x1] \n" + + "add v0.4s, v0.4s, v1.4s \n" + + "st1 {v0.4s}, [x1] \n" + + : + : [tmp_src] "r"(tmp_src), [tmp_dst] "r"(tmp_dst) + : "x0", "x1", "v0", "v1"); +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc*/ + + PostFuncInt8C4(tmp, bias, out, output_channel, (size_t)output_plane, conv_param->output_channel_, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + return NNACL_OK; +} + +void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, + bool support_optimize_) { + /* optimize normal -> same layout */ + int ic16 = UP_ROUND(input_channel, C16NUM); + int oc4 = UP_ROUND(output_channel, C4NUM); + for (int ic = 0; ic < input_channel; ic++) { + int ic16div = ic / C16NUM, ic16mod = ic % C16NUM; + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * output_channel * plane + hw * output_channel + oc; + int dst_index = hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod; + dst[dst_index] = src[src_index]; + } + } + } + return; +} + +void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, + int col4, bool suppport_opt) { + int deep16 = UP_ROUND(deep, C16NUM); + int32_t zp_sum = filter_zp * input_zp * deep; + for (int c = 0; c < col4; c++) { + int c4div = c / C4NUM, c4mod = c % C4NUM; + int32_t value = 0; + for (int r = 0; r < deep; r++) { + int r16div = r / C16NUM, r16mod = r % C16NUM; + int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod; + value += weight[src_index]; + } + weight_sum[c] = zp_sum - value * input_zp; + } + return; +} + +void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, + bool suppport_opt) { + /* optimize normal -> same layout */ + PackInputSum16x4PerLayer(src, dst, filter_zp, row4, col16); + return; +} + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, const int32_t *weight_sum, + const int32_t *input_sum, size_t act_row, size_t act_col, size_t act_deep, + const ConvParameter *conv_param, MATMUL_OPT_R4_FUNC matmul_func) { + if (matmul_func != NULL) { + matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); + } else { + MatMulInt8_16x4(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); + } + return NNACL_OK; +} + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param, bool support_optimize) { + /* optimize normal -> same layout (C4) */ + int error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param); + return error_code; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.h new file mode 100644 index 00000000..60a7a31d --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/deconv_int8.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_DECONV_H_
+#define NNACL_INT8_DECONV_H_
+
+#include <string.h>
+#include "nnacl_c/pack.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/matmul_int8.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep,
+                         int col4, bool suppport_opt);
+void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
+                        bool suppport_opt);
+void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,
+                           bool support_optimize_);
+
+int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, const int32_t *weight_sum,
+               const int32_t *input_sum, size_t act_row, size_t act_col, size_t act_deep,
+               const ConvParameter *conv_param, MATMUL_OPT_R4_FUNC matmul_func);
+int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
+                   ConvParameter *conv_param, bool support_optimize);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_DECONV_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.c
new file mode 100644
index 00000000..be317eea
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.c
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "nnacl_c/int8/depth_to_space_int8.h" +#include + +void DepthToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, DepthToSpaceArgs *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + int64_t copy_size = block_size * param->out_stride_dim2_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + int64_t out_offset = out_offset_w + l * param->out_stride_dim1_; + int64_t in_offset = in_offset_w + l * block_size * param->out_stride_dim2_; + for (int m = 0; m < copy_size; ++m) { + int32_t output_tmp = round(input[in_offset + m] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[out_offset + m] = output_tmp; + } + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.h new file mode 100644 index 00000000..faaba589 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/depth_to_space_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ +#define NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ + +#include "nnacl_c/depth_to_space_parameter.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/kernel/depth_to_space.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DepthToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, DepthToSpaceArgs *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.c new file mode 100644 index 00000000..2dc62cd3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/div_int8.h" + +int DivInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para) { + int index = 0; + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + if (input1_val == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + + int recip_shift; + const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift) + : -ComputerReciprocal(-input1_val, 31, &recip_shift); + const int leading_bits = CountLeadingSignBits(input0_val); + const int32_t raw_data = + SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); + const int total_shift = para->output_shift_ - recip_shift - leading_bits; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) + + para->out_args_.zp_; + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return NNACL_OK; +} + +int DivScalarInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para) { + int index = 0; + const int32_t input1_val = para->in1_args_.zp_ + *input1_data; + if (input1_val == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int recip_shift; + const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift) + : -ComputerReciprocal(-input1_val, 31, &recip_shift); + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + + const int leading_bits = CountLeadingSignBits(input0_val); + const int32_t raw_data = + SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); + const int total_shift = para->output_shift_ - recip_shift - leading_bits; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) + + para->out_args_.zp_; + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.h new file mode 100644 index 00000000..dcd6210d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/div_int8.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DIV_INT8_H_ +#define NNACL_INT8_DIV_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/fixed_point.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DivInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para); + +int DivScalarInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DIV_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.c new file mode 100644 index 00000000..a1aad3b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.c @@ -0,0 +1,76 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include "nnacl_c/int8/dynamic_gather_int8.h" +#include "nnacl_c/op_base.h" + +void DynamicGather(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float *output, const float *scale_in, const int32_t *zp_in) { + for (int m = 0; m < outer_size; ++m) { + const int8_t *int8_in_m = input + inner_size * m * limit; + float *int8_out_m = output + inner_size * m * indices_element_size; + for (int i = 0; i < indices_element_size; ++i) { + int index = indices[i]; + index = index < 0 ? index + limit : index; + const float scale = scale_in[index]; + const int zp = zp_in[index]; + float *out = int8_out_m + i * inner_size; + const int8_t *src = int8_in_m + index * inner_size; +#ifndef ENABLE_ARM64 + for (int j = 0; j < inner_size; ++j) { + out[j] = (src[j] - zp) * scale; + } +#else + int count_16 = DOWN_ROUND(inner_size, C16NUM); + DynamicGatherArm64(src, out, count_16, zp, scale); + for (int j = count_16; j < inner_size; ++j) { + out[j] = (src[j] - zp) * scale; + } +#endif + } + } + return; +} + +#ifdef ENABLE_FP16 +void DynamicGatherForFp16(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float16_t *output, const float *scale_in, const int32_t *zp_in) { + for (int m = 0; m < outer_size; ++m) { + const int8_t *int8_in_m = input + inner_size * m * limit; + float16_t *int8_out_m = output + inner_size * m * indices_element_size; + for (int i = 0; i < indices_element_size; ++i) { + int index = indices[i]; + index = index < 0 ? 
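+      /* A negative gather index wraps around the axis length (index += limit), and each
+       * gathered row is dequantized with its own per-index params:
+       *   out[j] = (src[j] - zp_in[index]) * scale_in[index].
+       * For example, with hypothetical params zp = 2 and scale = 0.1f, an int8 value
+       * of 12 dequantizes to 1.0. The fp32 and fp16 variants share this logic. */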
index + limit : index; + const float scale = scale_in[index]; + const int zp = zp_in[index]; + float16_t *out = int8_out_m + i * inner_size; + const int8_t *src = int8_in_m + index * inner_size; +#ifndef ENABLE_ARM64 + for (int j = 0; j < inner_size; ++j) { + out[j] = (float16_t)(src[j] - zp) * scale; + } +#else + int count_16 = DOWN_ROUND(inner_size, C16NUM); + DynamicGatherArm64ForFp16(src, out, count_16, zp, scale); + for (int j = count_16; j < inner_size; ++j) { + out[j] = (float16_t)((src[j] - zp) * scale); + } +#endif + } + } + return; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.h new file mode 100644 index 00000000..b51491b7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_gather_int8.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DYNAMIC_GATHER_INT8_H_ +#define NNACL_INT8_DYNAMIC_GATHER_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DynamicGather(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float *output, const float *scale_in, const int32_t *zp_in); +void DynamicGatherArm64(const int8_t *src, float *output, int count_16, int zp, float scale); + +#ifdef ENABLE_FP16 +void DynamicGatherForFp16(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float16_t *output, const float *scale_in, const int32_t *zp_in); +void DynamicGatherArm64ForFp16(const int8_t *src, float16_t *output, int count_16, int zp, float scale); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_GATHER_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.c new file mode 100644 index 00000000..5c4d49d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.c @@ -0,0 +1,420 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/dynamic_matmul_int8.h" +#include "nnacl_c/int8/fixed_point.h" + +void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode) { + /* * + * row4x4-major * row4x16-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + int64_t s2 = a_sums[r]; + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + int32_t s1 = 0; + for (int d = 0; d < deep4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod; + s1 += a[ai] * b[bi]; + } + int64_t s3 = b_sums[c] * a_zp; + int64_t s4 = a_zp * b_zp_sum; + size_t ci = r * stride / sizeof(float) + c; + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); + if (bias != NULL) { + out[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } + } + } + return; +} + +#ifdef ENABLE_FP16 +void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode) { + /* * + * row4x4-major * row4x16-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + int64_t s2 = a_sums[r]; + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + int32_t s1 = 0; + for (int d = 0; d < deep4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod; + s1 += a[ai] * b[bi]; + } + int64_t s3 = b_sums[c] * a_zp; + int64_t s4 = a_zp * b_zp_sum; + size_t ci = r * stride / sizeof(float16_t) + c; + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? 
r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); + if (bias != NULL) { + out[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } + } + } + return; +} +#endif + +void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, + int deep, int deep16, size_t stride, int input_zp, const float *input_scale, + const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel, + int64_t act_type) { + /* * + * row4x16-major * row16x4-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int32_t value = 0; + int32_t s0 = 0; + int32_t s1 = 0; + int32_t s2 = 0; + int32_t s3 = 0; + for (int d = 0; d < deep; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + s0 += a[ai] * b[bi]; + s1 += filter_zp * a[ai]; + s2 += input_zp * b[bi]; + s3 += input_zp * filter_zp; + } + value = s0 - s1 - s2 + s3; + int input_quant_index = input_per_channel ? r : 0; + int filter_quant_index = filter_per_channel ? c : 0; + float multi_scale = input_scale[input_quant_index] * filter_scale[filter_quant_index]; + size_t ci = r * stride + c; + dst[ci] = multi_scale * value; + if (bias != NULL) { + dst[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + dst[ci] = MSMAX(0, dst[ci]); + } else if (act_type == ActType_Relu6) { + dst[ci] = MSMAX(0, dst[ci]); + dst[ci] = MSMIN(C6NUM, dst[ci]); + } + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput4x4Asm(const int8_t *src_ic, int8_t *pack_ic, size_t ic_4div, size_t input_channel) { + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v2.4s, wzr \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x15, #0 \n" + "1: \n" + "cmp x15, %[ic_4div] \n" + "add x15, x15, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 1b \n" + + "3: \n" /* ic res 1 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "4: \n" /* ic res 2 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "5: \n" /* ic res 3 */ + "dup v0.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + 
"ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "6: \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [src_stride] "r"(src_stride), [ic_4div] "r"(ic_4div), + [ic_4res] "r"(ic_4res) + : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); +} +#endif + +void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_4div = plane_size / C4NUM * C4NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; +#ifdef ENABLE_ARM64 + PackInput4x4Asm(src_ic, pack_ic, ic_4div, input_channel); +#else + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (size_t i = 0; i < C4NUM; i++) { + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } +#endif + src_r += input_channel * C4NUM; + pack_r += ic4 * C4NUM; + } + + if (hw_4div != plane_size) { + memset(pack_r, 0, C4NUM * ic4); + for (int hwi = hw_4div; hwi < plane_size; hwi += 1) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + src_r += input_channel; + pack_r += C4NUM; + } + } + return; +} + +// For matmul input a transpose case +void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride) { + const int row_tile = C4NUM; + int row_align = UP_ROUND(row, row_tile); + int row_div = row / row_tile * row_tile; + const int row_res = row - row_div; + + const int col_tile = C4NUM; + int col_div = col / col_tile * col_tile; + const int col_res = col - col_div; + + const int8_t *src_ic = NULL; + int8_t *packed_ic = NULL; + for (int c = 0; c < col_div; c += C4NUM) { + int r = 0; + src_ic = src_input + c; + packed_ic = packed_input + c * row_align; +#ifdef ENABLE_ARM64 + size_t row_stride_int64 = row_stride; + asm volatile( + "mov w10, %w[row]\n" + "mov x11, %[src_ic]\n" + "mov x12, %[packed_ic]\n" + "cmp w10, wzr\n" + "beq 1f\n" + "2:\n" + "subs w10, w10, #4\n" + "ld1 {v0.s}[0], [x11], %[row_stride]\n" + "ld1 {v1.s}[0], [x11], %[row_stride]\n" + "ld1 {v0.s}[1], [x11], %[row_stride]\n" + "ld1 {v1.s}[1], [x11], %[row_stride]\n" + "zip1 v2.8b, v0.8b, v1.8b\n" + "zip2 v3.8b, v0.8b, v1.8b\n" + "zip1 v4.4h, v2.4h, v3.4h\n" + "zip2 v5.4h, v2.4h, v3.4h\n" + "st1 {v4.4h, v5.4h}, [x12], #16\n" + + "bgt 2b\n" + "1:\n" + + : + : [src_ic] "r"(src_ic), [packed_ic] "r"(packed_ic), [row] "r"(row_div), [row_stride] "r"(row_stride_int64) + : "memory", "w10", "x11", "x12", "v0", "v1", "v2", 
"v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12"); + packed_ic += C4NUM * row_div; + src_ic += row_div * row_stride; +#else + for (; r < row_div; r += C4NUM) { + for (int i = 0; i < row_tile; i++) { + packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0]; + packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1]; + packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2]; + packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3]; + } + packed_ic += C16NUM; + src_ic += row_tile * row_stride; + } +#endif + for (r = 0; r < row_res; ++r) { + for (int i = 0; i < C4NUM; ++i) { + packed_ic[i * row_tile + r] = src_ic[r * row_stride + i]; + } + } + } + if (col_res == 0) { + return; + } + src_ic = src_input + col_div; + packed_ic = packed_input + row_align * col_div; + for (int r = 0; r < row_div; r += row_tile) { + for (int i = 0; i < col_res; ++i) { + packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i]; + packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i]; + packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i]; + packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i]; + } + src_ic += row_tile * row_stride; + packed_ic += row_tile * col_tile; + } + + for (int r = 0; r < row_res; ++r) { + for (int c = 0; c < col_res; ++c) { + packed_ic[c * row_tile + r] = src_ic[r * row_stride + c]; + } + } +} + +void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order) { + if (order == RowMajor) { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[r * col + c]; + } + dst[c] = sum; + } + } else { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[c * row + r]; + } + dst[c] = sum; + } + } + return; +} + +void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order) { + if (order == RowMajor) { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[r * stride + c]; + } + dst[c] = sum; + } + } else { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[c * row + r]; + } + dst[c] = sum; + } + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.h new file mode 100644 index 00000000..5b36268e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_matmul_int8.h @@ -0,0 +1,74 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_DYNAMIC_MATMUL_H_
+#define NNACL_INT8_DYNAMIC_MATMUL_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/matmul_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride);
+void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size);
+void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
+                             int deep, int deep16, size_t stride, int input_zp, const float *input_scale,
+                             const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel,
+                             int64_t act_type);
+void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order);
+void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order);
+#if defined(ENABLE_ARM64) && !defined(USE_AOS_GCC_TOOLCHAIN)
+/*
+ * mode is used to distinguish different quantization scenarios, whose value is 0-3.
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales,
+                                 const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
+                                 const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode);
+#endif
+/*
+ * mode is used to distinguish different quantization scenarios, whose value is 0-3.
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales,
+                             const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
+                             const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode);
+#ifdef ENABLE_FP16
+/*
+ * mode is used to distinguish different quantization scenarios, whose value is 0-3.
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4,
+                                    const float *multi_scales, const float16_t *bias, size_t row, size_t col,
+                                    size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp,
+                                    int64_t b_zp_sum, int64_t act_type, int64_t mode);
+/*
+ * mode is used to distinguish different quantization scenarios, whose value is 0-3.
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4,
+                                        const float *multi_scales, const float16_t *bias, size_t row, size_t col,
+                                        size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp,
+                                        int64_t b_zp_sum, int64_t act_type, int64_t mode);
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_DYNAMIC_MATMUL_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.c
new file mode 100644
index 00000000..3cd0669c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.c
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/dynamic_quant_int8.h" + +void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) { + if (count == 0) { + return; + } +#ifndef ENABLE_ARM64 + for (int i = 0; i < count; ++i) { + *real_min = data[i] < *real_min ? data[i] : *real_min; + *real_max = data[i] > *real_max ? data[i] : *real_max; + } +#else + // avoid to compile optimize. + volatile int count_4 = DOWN_ROUND(count, C4NUM); + asm volatile( + "mov x4, %[data]\n" // reload data + "mov w5, %w[count_4]\n" // reload count + "ld1 {v31.4s}, [x4]\n" // min + "ld1 {v30.4s}, [x4], #16\n" // max + "subs w5, w5, #4\n" + "ble 1f\n" + + "0:\n" + "ld1 {v0.4s}, [x4], #16\n" + "fmin v31.4s, v31.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "subs w5, w5, #4\n" + "bgt 0b\n" + + "1:\n" + "fminv s6, v31.4s\n" + "fmaxv s7, v30.4s\n" + + "str s6, [%[real_min]]\n" + "str s7, [%[real_max]]\n" + + : + : [data] "r"(data), [count_4] "r"(count_4), [real_min] "r"(real_min), [real_max] "r"(real_max) + : "x4", "w5", "s6", "s7", "v0", "v30", "v31"); + for (int i = count_4; i < count; ++i) { + *real_min = data[i] < *real_min ? data[i] : *real_min; + *real_max = data[i] > *real_max ? data[i] : *real_max; + } +#endif +} + +void CalculateChannelRowMinMax(const float *data, int count, float *real_min, float *real_max, int row_length) { + if (row_length == 0) { + return; + } + int channel_total = count / row_length; + for (int i = 0; i < channel_total; i++) { + CalculateMinMaxFp32(data + i * row_length, row_length, real_min + i, real_max + i); + } +} + +void CalculateChannelColMinMax(const float *data, int count, float *real_min, float *real_max, int row_length) { + if (row_length == 0) { + return; + } + int row_total = count / row_length; + for (int r = 0; r < row_total; r++) { + const float *data_current = data + r * row_length; + for (int c = 0; c < row_length; c++) { + float *real_min_channel = real_min + c; + float *real_max_channel = real_max + c; + if (data_current[c] < *real_min_channel) { + *real_min_channel = data_current[c]; + } + if (data_current[c] > *real_max_channel) { + *real_max_channel = data_current[c]; + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.h new file mode 100644 index 00000000..05f26e68 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/dynamic_quant_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_DYNAMIC_QUANT_INT8_H_
+#define NNACL_INT8_DYNAMIC_QUANT_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/pow_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max);
+void CalculateChannelRowMinMax(const float *data, int count, float *real_min, float *real_max, int row_length);
+void CalculateChannelColMinMax(const float *data, int count, float *real_min, float *real_max, int row_length);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_DYNAMIC_QUANT_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.c
new file mode 100644
index 00000000..e77ba048
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.c
@@ -0,0 +1,276 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/fixed_point.h"
+
+#define C31NUM 31
+
+// Returns the high 32 bits of a * b with rounding.
+// Assume that a and b are scaled by 2^31, so both fall into [-1, 1];
+// the mantissa of a * b is then (a / 2^31) * (b / 2^31) * 2^31 = (a * b) / 2^31.
+// In practice we compute 2 * a * b / 2^32 and take 32 bits of mantissa for rounding.
+int SaturatingRoundingDoublingHighMul(int a, int b) {
+  if (a == INT_MIN && b == INT_MIN) {
+    return INT_MAX;
+  }
+  int64_t ab = ((int64_t)a) * ((int64_t)b);
+  int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30));
+  // do not apply a right shift to potentially negative values
+  int ab_mantissa = (int)((ab + rounding) / (1ll << 31));
+  return ab_mantissa;
+}
+
+int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) {
+  if (a == SHRT_MIN && b == SHRT_MIN) {
+    return SHRT_MAX;
+  }
+  int32_t ab = ((int32_t)a) * ((int32_t)b);
+  int16_t rounding = ab >= 0 ? (1ll << 14) : (1ll - (1ll << 14));
+  return (int16_t)((ab + rounding) / (1ll << 15));
+}
+
+// division by 2^exponent with rounding,
+// i.e. an arithmetic right shift with rounding
+int RoundingDivideByPOT(int x, int exponent) {
+  if (exponent > C31NUM) {
+    exponent = C31NUM;
+  }
+  const int mask = (1ll << exponent) - 1;
+  const int remainder = x & mask;
+  const int threshold = (mask >> 1) + (x < 0 ? 1 : 0);
+  return (x >> exponent) + (remainder > threshold ? 1 : 0);
+}
+
+int UpwardRounding(int x, int exponent) {
+  const int32_t rounding_offset = (exponent > 0) ?
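+  /* Round-half-up right shift: (x + 2^(exponent - 1)) >> exponent, with a guard so
+   * adding the rounding offset cannot overflow INT32_MAX. Worked example: x = 5 and
+   * exponent = 1 give (5 + 1) >> 1 = 3, where a plain truncating shift would give 2. */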
(1 << (exponent - 1)) : 0;
+  if (x > INT32_MAX - rounding_offset) {
+    return 1 << (31 - exponent);
+  }
+  return (x + rounding_offset) >> exponent;
+}
+
+int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
+  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
+}
+
+int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
+                                                    int32_t right_shift) {
+  return UpwardRounding(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
+}
+
+int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) {
+  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift);
+}
+
+int FractionsBits(int integer_bits) { return 8 * (int)(sizeof(int32_t)) - 1 - integer_bits; }
+
+int FixedPoint_One(int integer_bits, int fractions_bits) {
+  return (integer_bits == 0 ? INT32_MAX : ((1) << (uint32_t)(integer_bits == 0 ? 0 : fractions_bits)));
+}
+
+int RoundingHalfSum(int32_t a, int32_t b) {
+  int64_t sum = (int64_t)a + (int64_t)b;
+  return (int32_t)((sum + (sum > 0 ? 1 : -1)) / 2);
+}
+
+int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; }
+
+int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; }
+
+int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; }
+
+int32_t BitNot(int32_t a) { return ~(uint32_t)a; }
+
+int BitsSelect(int mask, int bound, int val) { return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); }
+
+int ConstantPOT(int fractional_bits, int exponent) { return (1 << (uint32_t)(fractional_bits + exponent)); }
+
+int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; }
+
+int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); }
+
+int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); }
+
+uint32_t CountLeadingZeroBits(uint32_t x) {
+  if (x == 0) {
+    return 8 * sizeof(uint32_t) - 1;
+  }
+#if defined(__GNUC__)
+  return __builtin_clz(x);
+#else
+  const uint32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1);
+  uint32_t leading_zeros = 0;
+  while (x < leading_positive) {
+    x <<= 1;
+    leading_zeros++;
+  }
+  return leading_zeros;
+#endif
+}
+
+uint32_t CountLeadingSignBits(int32_t x) {
+  if (x == 0) {
+    return 8 * sizeof(int32_t) - 1;
+  }
+#if defined(__GNUC__) && !defined(__clang__)
+  return __builtin_clrsb(x);
+#else
+  return x >= 0 ? CountLeadingZeroBits((uint32_t)x) - 1 : x != INT32_MIN ? CountLeadingZeroBits(2 * (uint32_t)(-x)) : 0;
+#endif
+}
+
+int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent) {
+  if (exponent > 0) {
+    const int min = INT32_MIN;
+    const int max = INT32_MAX;
+    const int scalar_int_bits = 8 * (int)(sizeof(int32_t));
+    const int threshold = ((1 << (uint32_t)(scalar_int_bits - 1 - exponent)) - 1);
+    const int positive_mask = x > threshold ? BitNot(0) : 0;
+    const int negative_mask = x < -threshold ?
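+    /* For a left shift by exponent, any |x| above threshold = 2^(31 - exponent) - 1
+     * would overflow int32, so the branchless masks clamp those cases to INT32_MAX or
+     * INT32_MIN. Example: exponent = 1 gives threshold = 2^30 - 1. */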
BitNot(0) : 0; + int result = x * ((int32_t)(1) << (uint32_t)exponent); + result = BitsSelect(positive_mask, max, result); + result = BitsSelect(negative_mask, min, result); + return result; + } else if (exponent < 0) { + return RoundingDivideByPOT(x, -exponent); + } else { + return x; + } +} + +int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst) { + int exponent = integer_bits_src - integer_bits_dst; + return SaturatingRoundingMultiplyByPOT(x, exponent); +} + +int32_t reciprocal_on_interval_between_0_1(int32_t a) { + int one = FixedPoint_One(0, FractionsBits(0)); + int half_sum = RoundingHalfSum(a, one); + const int constant_48_over_17 = 1515870810; + const int constant_neg_32_over_17 = -1010580540; + int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_sum, constant_neg_32_over_17); + for (int i = 0; i < 3; i++) { + int half_sum_times_x = SaturatingRoundingDoublingHighMul(half_sum, x); + int one_minus_half_sum_times_x = FixedPoint_One(2, FractionsBits(2)) - half_sum_times_x; + x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_sum_times_x), 2 + 2, 2); + } + return Rescale(x, 2 - 1, 0); +} + +int32_t ComputerReciprocal(int32_t x, uint32_t x_digits, int32_t *recip_shift) { + uint32_t leading_zreos_plus_one = CountLeadingZeroBits((uint32_t)x); + *recip_shift = x_digits - leading_zreos_plus_one; + const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zreos_plus_one) - ((uint32_t)(1) << 31)); + const int32_t shifted_scaled = reciprocal_on_interval_between_0_1(shifted_minus_one); + return shifted_scaled; +} + +int exp_on_interval_values(int a) { + const int constant_neg_1_over_8 = 1895147668; + const int constant_1_over_3 = 715827883; + int fractional_bits = FractionsBits(0); + int x = a + ConstantPOT(fractional_bits, -3); + int x2 = SaturatingRoundingDoublingHighMul(x, x); + int x3 = SaturatingRoundingDoublingHighMul(x2, x); + int x4 = SaturatingRoundingDoublingHighMul(x2, x2); + int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2); + int x4_over_24_plus_x3_over_6_plus_x2_over_2 = + SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1); + return constant_neg_1_over_8 + + SaturatingRoundingDoublingHighMul(constant_neg_1_over_8, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); +} + +void exp_barrel_shifter(int exponent, int muliplier, int integer_bits, int fractional_bits, int remainder, + int32_t *result) { + if (integer_bits > exponent) { + int total_shift = integer_bits > exponent ? 
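+    /* One barrel-shifter step of the exp() evaluation: when the remainder has the bit
+     * worth 2^exponent set, the running result is multiplied by a precomputed Q31
+     * constant approximating exp(-2^exponent). The caller's table matches this, e.g.
+     * 1672461947 / 2^31 is roughly 0.7788, which is exp(-0.25), for exponent = -2. */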
fractional_bits + exponent : 0;
+    *result = BitsSelect(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)total_shift))),
+                         SaturatingRoundingDoublingHighMul(*result, muliplier), *result);
+  }
+}
+
+int exp_on_negative_values(int a, const int integer_bits) {
+  int fractional_bits = FractionsBits(integer_bits);
+  const int one_quarter = ConstantPOT(fractional_bits, -2);
+  int a_mod_quarter_minus_one_quarter = ((unsigned)(a) & (one_quarter - 1)) - one_quarter;
+  int result = exp_on_interval_values(Rescale(a_mod_quarter_minus_one_quarter, integer_bits, 0));
+  int remainder = a_mod_quarter_minus_one_quarter - a;
+
+  exp_barrel_shifter(-2, 1672461947, integer_bits, fractional_bits, remainder, &result);
+  exp_barrel_shifter(-1, 1302514674, integer_bits, fractional_bits, remainder, &result);
+  exp_barrel_shifter(+0, 790015084, integer_bits, fractional_bits, remainder, &result);
+  exp_barrel_shifter(+1, 290630308, integer_bits, fractional_bits, remainder, &result);
+  exp_barrel_shifter(+2, 39332535, integer_bits, fractional_bits, remainder, &result);
+  exp_barrel_shifter(+3, 720401, integer_bits, fractional_bits, remainder, &result);
+  exp_barrel_shifter(+4, 242, integer_bits, fractional_bits, remainder, &result);
+
+  int clamp_bits = integer_bits > 5 ? 36 - integer_bits : 0;
+  if (integer_bits > 5) {
+    const int clamp = -(1 << (uint32_t)clamp_bits);
+    result = BitsSelect(MaskIfLessThan(a, clamp), 0, result);
+  }
+  result = BitsSelect(MaskIfZero(a), FixedPoint_One(0, fractional_bits), result);
+  return result;
+}
+
+void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift) {
+  if (input <= 1) {
+    *multiplier = INT_MAX;
+    *shift = 0;
+    return;
+  }
+  *shift = 11;
+  while (input >= (1 << 29)) {
+    input /= 4;
+    ++*shift;
+  }
+  uint32_t max_left_shift_bits = CountLeadingSignBits(input);
+  if (max_left_shift_bits < 2) {
+    return;
+  }
+  uint32_t left_shift_bit_pairs = max_left_shift_bits / 2 - 1;
+  *shift -= left_shift_bit_pairs;
+  input <<= 2 * left_shift_bit_pairs;
+  int32_t fixedpoint_f3_input = input >> 1;  // sign: 1 bit, integer: 3 bit, fractional: 28 bit
+  int32_t fp_f3_half_input = SaturatingRoundingMultiplyByPOT(fixedpoint_f3_input, -1);
+  int32_t fp_f3_half_three = (1 << 28) + (1 << 27);
+  int32_t tmp = (1 << 28);  // one
+  for (int i = 0; i < 5; i++) {
+    int32_t tmp3 = Rescale(SaturatingRoundingDoublingHighMul(tmp, SaturatingRoundingDoublingHighMul(tmp, tmp)), 9, 3);
+    tmp = Rescale(SaturatingRoundingDoublingHighMul(fp_f3_half_three, tmp) -
+                    SaturatingRoundingDoublingHighMul(fp_f3_half_input, tmp3),
+                  6, 3);
+  }
+  const int32_t fp_f0_half_sqrt_2 = 1518500250;  // sqrt(2) / 2
+  tmp = SaturatingRoundingDoublingHighMul(tmp, fp_f0_half_sqrt_2);
+  *multiplier = tmp;
+  if (*shift < 0) {
+    *multiplier <<= -*shift;
+    *shift = 0;
+  }
+  *shift *= reverse_shift;
+}
+
+#ifdef ENABLE_NEON
+int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
+  const int32x4_t shift_vec = vdupq_n_s32(-exponent);
+  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
+  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
+  return vrshlq_s32(fixed_up_x, shift_vec);
+}
+
+int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b) { return vqrdmulhq_s32(a, b); }
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.h
new file mode 100644
index 00000000..503a5e1d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/fixed_point.h
@@ -0,0 +1,74 @@
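+/* Most routines declared below operate on Q31 fixed point: a raw int32 v represents
+ * v / 2^31 in [-1, 1). SaturatingRoundingDoublingHighMul is the Q31 product; for
+ * example, a = b = 2^30 (0.5 in Q31) yields (2^60 + 2^30) / 2^31 = 2^29, i.e. 0.25. */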
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_QUANTIZATION_FIXED_POINT_H_
+#define NNACL_QUANTIZATION_FIXED_POINT_H_
+
+#include <stdint.h>
+#include <limits.h>
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// returns the high 32 bits of a * b with rounding
+// assume that a and b are divided by 2^31, so both fall into [-1, 1];
+// the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31 = (a * b) / 2^31
+// actually we compute 2 * a * b / 2^32
+// and take 32 bits of mantissa for rounding
+int SaturatingRoundingDoublingHighMul(int a, int b);
+
+int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b);
+
+// division by a 2^exponent with rounding
+// or arithmetic right shift with rounding
+int RoundingDivideByPOT(int x, int exponent);
+
+int UpwardRounding(int x, int exponent);
+
+int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift);
+
+int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
+                                                    int32_t right_shift);
+
+int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift);
+
+int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent);
+
+int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst);
+
+uint32_t CountLeadingSignBits(int32_t x);
+
+int32_t ComputerReciprocal(int32_t x, uint32_t x_digits, int32_t *recip_shift);
+
+int exp_on_negative_values(int a, const int integer_bits);
+
+void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift);
+
+#ifdef __cplusplus
+}
+#endif
+
+#ifdef ENABLE_NEON
+int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent);
+
+int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b);
+#endif
+
+#endif  // NNACL_QUANTIZATION_FIXED_POINT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.c
new file mode 100644
index 00000000..22d06187
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.c
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
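+ */
+
+/* Copies `count` slices of `area` int8 values selected by the precomputed
+ * offsets in `in_offset`, requantizing each element as
+ *   q_out = clamp(round(alpha * (q_in - zp_in)) + zp_out, -128, 127). */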
+#include "nnacl_c/int8/gatherNd_int8.h"
+#include <math.h>
+#include "nnacl_c/errorcode.h"
+
+int GatherNdInt8(int8_t *input, int8_t *output, const int32_t *in_offset, int area, int count, GatherQuantArg param) {
+  double alpha = param.alpha_;
+  int z1 = param.zp_in_;
+  int z2 = param.zp_out_;
+  for (int i = 0; i < count; ++i) {
+    for (int j = 0; j < area; ++j) {
+      int32_t tmp = round(alpha * (input[in_offset[i] + j] - z1)) + z2;
+      tmp = tmp > 127 ? 127 : tmp;
+      tmp = tmp < -128 ? -128 : tmp;
+      output[area * i + j] = (int8_t)tmp;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.h
new file mode 100644
index 00000000..91c74ab6
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gatherNd_int8.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_GATHERND_INT8_H_
+#define NNACL_INT8_GATHERND_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int GatherNdInt8(int8_t *in_data, int8_t *out_data, const int32_t *in_offset, int area, int count,
+                 GatherQuantArg param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_GATHERND_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.c
new file mode 100644
index 00000000..148eb0c3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.c
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
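+ *
+ */
+
+/* Both gather kernels below requantize while copying:
+ *   q_out = clamp(round(alpha * (q_in - zp_in)) + zp_out, -128, 127)
+ * e.g. with alpha = 0.5, zp_in = 0, zp_out = 10, an input of 64 becomes
+ * round(0.5 * 64) + 10 = 42. */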
+#include "nnacl_c/int8/gather_int8.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/errorcode.h"
+
+int GatherInt8Int32Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit,
+                         const int32_t *indices, int indices_element_size, GatherQuantArg para) {
+  double alpha = para.alpha_;
+  int z1 = para.zp_in_;
+  int z2 = para.zp_out_;
+  int i, m, j;
+  for (m = 0; m < outer_size; ++m) {
+    const int8_t *inputm = in_data + inner_size * m * limit;
+    int8_t *outputm = out_data + inner_size * m * indices_element_size;
+    for (i = 0; i < indices_element_size; ++i) {
+      if (indices[i] < 0 || indices[i] >= limit) {  // an index equal to limit is already out of range
+        return NNACL_ERR;
+      }
+      for (j = 0; j < inner_size; ++j) {
+        int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2;
+        tmp = tmp > 127 ? 127 : tmp;
+        tmp = tmp < -128 ? -128 : tmp;
+        outputm[i * inner_size + j] = (int8_t)tmp;
+      }
+    }
+  }
+  return NNACL_OK;
+}
+
+int GatherInt8Int64Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit,
+                         const int64_t *indices, int indices_element_size, GatherQuantArg para) {
+  double alpha = para.alpha_;
+  int z1 = para.zp_in_;
+  int z2 = para.zp_out_;
+  int i, m, j;
+  for (m = 0; m < outer_size; ++m) {
+    const int8_t *inputm = in_data + inner_size * m * limit;
+    int8_t *outputm = out_data + inner_size * m * indices_element_size;
+    for (i = 0; i < indices_element_size; ++i) {
+      if (indices[i] < 0 || indices[i] >= limit) {
+        return NNACL_ERR;
+      }
+      for (j = 0; j < inner_size; ++j) {
+        int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2;
+        tmp = tmp > 127 ? 127 : tmp;
+        tmp = tmp < -128 ? -128 : tmp;
+        outputm[i * inner_size + j] = (int8_t)tmp;
+      }
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.h
new file mode 100644
index 00000000..86d3664b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/gather_int8.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef NNACL_INT8_GATHER_INT8_H_ +#define NNACL_INT8_GATHER_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int GatherInt8Int32Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int32_t *indices, int indices_element_size, GatherQuantArg para); + +int GatherInt8Int64Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int64_t *indices, int indices_element_size, GatherQuantArg para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_GATHER_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.c new file mode 100644 index 00000000..01393dac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.c @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/hswish_int8.h" + +int16_t SaturatingLeftShift(int16_t value, int shift_num) { + int32_t result = (int32_t)value * (1 << shift_num); + return MSMAX(MSMIN(result, SHRT_MAX), SHRT_MIN); +} + +int HSwishInt8(const int8_t *src, int length, int8_t *dst, const HswishQuantArg *arg) { + for (int i = 0; i < length; i++) { + const int16_t input_value = src[i] - arg->input_zp; + const int16_t input_value_scale = input_value * (1 << 7); + const int16_t input_value_on_preshift_output_scale = + SaturatingRoundingDoublingHighMulInt16(input_value_scale, arg->output_multiplier_fixedpoint_int16); + int16_t relu6_value = input_value_scale; + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, arg->relu6_multiplier_exponent - 1); + } + relu6_value = SaturatingRoundingDoublingHighMulInt16(relu6_value, arg->relu6_multiplier_fixedpoint_int16); + + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, 1); + } + if (arg->relu6_multiplier_exponent < 0) { + relu6_value = RoundingDivideByPOT(relu6_value, -arg->relu6_multiplier_exponent); + } + relu6_value = (size_t)(relu6_value + (1 << 15)) >> 1; + const int16_t preshift_output_value = + SaturatingRoundingDoublingHighMulInt16(relu6_value, input_value_on_preshift_output_scale); + + int16_t output = RoundingDivideByPOT(preshift_output_value, -arg->output_multiplier_exponent); + output += arg->output_zp; + output = MSMIN(output, 127); + output = MSMAX(output, -128); + dst[i] = (int8_t)output; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.h new file mode 100644 index 00000000..688f62ec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/hswish_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_HSWISH_INT8_H_
+#define NNACL_INT8_HSWISH_INT8_H_
+
+#include <limits.h>  // SHRT_MAX / SHRT_MIN for SaturatingLeftShift()
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/int8/fixed_point.h"
+
+typedef struct HswishQuantArg {
+  double input_scale;
+  int32_t input_zp;
+  double output_scale;
+  int32_t output_zp;
+  int16_t relu6_multiplier_fixedpoint_int16;
+  int32_t relu6_multiplier_exponent;
+  int16_t output_multiplier_fixedpoint_int16;
+  int32_t output_multiplier_exponent;
+} HswishQuantArg;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int HSwishInt8(const int8_t *src, int length, int8_t *dst, const HswishQuantArg *arg);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_HSWISH_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.c
new file mode 100644
index 00000000..233ff4a1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.c
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/int8/l2_norm_int8.h"
+#include <math.h>
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/errorcode.h"
+
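+/* Per inner vector: accumulate sum((x - zp_in)^2), turn it into a fixed-point
+ * multiplier/shift for the inverse square root via GetSqrtQuantMultiplierExp()
+ * (reverse_shift = -1), then rescale every centered element by it; inputs are
+ * pre-shifted by 2^7 so the normalized result saturates to int8 range. */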
+int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param,
+                        const L2NormQuantArg *quant_param, const int begin, const int end) {
+  const int inner_size = param->shape_[param->shape_num_ - 1];
+
+  for (int i = begin; i < end; ++i) {
+    int32_t square_sum = 0;
+    for (int j = 0; j < inner_size; ++j) {
+      int32_t in = input_data[i * inner_size + j] - quant_param->in_.zp_;
+      square_sum += in * in;
+    }
+    int32_t multiplier;
+    int32_t shift;
+    GetSqrtQuantMultiplierExp(square_sum, -1, &multiplier, &shift);
+    for (int k = 0; k < inner_size; ++k) {
+      int32_t in = input_data[i * inner_size + k] - quant_param->in_.zp_;
+      int32_t out = RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(in * (1 << 7), multiplier), -shift);
+      output_data[i * inner_size + k] = MSMIN(127, MSMAX(-128, out));
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.h
new file mode 100644
index 00000000..4cfec797
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/l2_norm_int8.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_L2_NORM_INT8_H_
+#define NNACL_INT8_L2_NORM_INT8_H_
+
+#include "nnacl_c/l2_norm_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param,
+                        const L2NormQuantArg *quant_param, const int begin, const int end);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_L2_NORM_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.c
new file mode 100644
index 00000000..81d829c5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.c
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/layer_norm_int8.h"
+
+void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data,
+                               const LayerNormQuantArg *quant, int num, const float mean, const float deno) {
+  for (int i = 0; i < num; i++) {
+    float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_;
+    float fp32_dst = (fp32_src - mean) * deno;
+    fp32_dst = fp32_dst * gamma_data[i] + beta_data[i];
+    int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_);
+    dst[i] = (int8_t)MSMAX(MSMIN(int32_dst, 127), -128);
+  }
+}
+
+/*
+ * origin : (x - mean) / sqrt(variance + epsilon) * gamma + beta
+ * quant  : (x - mean) / sqrt(E[x^2] - mean^2 + epsilon) * gamma + beta
+ */
+int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data,
+                  const LayerNormComputeParam *param, const LayerNormQuantArg *quant, int task_id, int thread_num) {
+  if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_);
+  NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_);
+
+  int step = UP_DIV(param->norm_outer_size_, thread_num);
+  int thread_end = NNACL_MIN((task_id + 1) * step, param->norm_outer_size_);
+  for (int i = task_id * step; i < thread_end; i++) {
+    const int8_t *src_norm = src_data + i * param->norm_inner_size_;
+    int8_t *dst_norm = dst_data + i * param->norm_inner_size_;
+    float mean = 0.0f;
+    float square_mean = 0.0f;
+    for (int j = 0; j < param->norm_inner_size_; j++) {
+      float float_src = (src_norm[j] - quant->in_zp_) * quant->in_scale_;
+      mean += float_src;
+      square_mean += float_src * float_src;
+    }
+    mean /= (float)param->norm_inner_size_;
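+    // single pass: variance = E[x^2] - E[x]^2, with epsilon added under the sqrt
+ 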
square_mean /= (float)param->norm_inner_size_; + const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); + + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const int8_t *src_param = src_norm + x * param->params_inner_size_; + int8_t *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBetaInt8(dst_param, src_param, gamma_data, beta_data, quant, param->norm_inner_size_, mean, + deno); + } + } else { + int x = i / param->params_outer_size_; + const float *gamma = gamma_data + x * param->norm_inner_size_; + const float *beta = beta_data + x * param->norm_inner_size_; + LayerNormGammaAndBetaInt8(dst_norm, src_norm, gamma, beta, quant, param->norm_inner_size_, mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.h new file mode 100644 index 00000000..e33b34b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/layer_norm_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_LAYER_NORM_H_ +#define NNACL_INT8_LAYER_NORM_H_ + +#include "nnacl_c/errorcode.h" +#include "nnacl_c/layer_norm_parameter.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data, + const LayerNormComputeParam *param, const LayerNormQuantArg *quant, int task_id, int thread_num); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_LAYER_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.c new file mode 100644 index 00000000..8b38c5f1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.c @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
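+ */
+
+/* Requantizes each element to the output scale, folding the leaky-relu slope
+ * into non-positive inputs:
+ *   x <= 0 : q_out = round(q_in * slope * scale + bias) + zp_out
+ *   x >  0 : q_out = round(q_in * scale + bias) + zp_out
+ * where scale = in_scale / out_scale and bias = -zp_in * scale; results
+ * saturate to [-128, 127], and elements are strided across threads. */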
+#include "nnacl_c/int8/leaky_relu_int8.h"
+#include "nnacl_c/errorcode.h"
+
+int DoLeakReluInt8(const int8_t *inputs, int8_t *output_ptr, const LeakyReluQuantArg *quant_prelu_parm, int task_id) {
+  if (quant_prelu_parm == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  float output_scale = quant_prelu_parm->out_args_.scale_;
+  int output_zp = quant_prelu_parm->out_args_.zp_;
+  const float output_inverse_scale = 1.f / output_scale;
+
+  float scale = quant_prelu_parm->in_args_.scale_ * output_inverse_scale;
+  float bias = -quant_prelu_parm->in_args_.zp_ * scale;
+  for (int j = task_id; j < quant_prelu_parm->element_num; j += quant_prelu_parm->thread_num_) {
+    if (inputs[j] <= 0) {
+      int32_t output_tmp = round(inputs[j] * quant_prelu_parm->slope_ * scale + bias) + output_zp;
+      if (output_tmp > 127) {
+        output_ptr[j] = 127;
+      } else if (output_tmp < -128) {
+        output_ptr[j] = -128;
+      } else {
+        output_ptr[j] = (int8_t)output_tmp;
+      }
+    } else {
+      int32_t output_tmp = round(inputs[j] * scale + bias) + output_zp;
+      if (output_tmp > 127) {
+        output_ptr[j] = 127;
+      } else if (output_tmp < -128) {
+        output_ptr[j] = -128;
+      } else {
+        output_ptr[j] = (int8_t)output_tmp;
+      }
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.h
new file mode 100644
index 00000000..cde80295
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/leaky_relu_int8.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_PRELU_INT8_H_
+#define NNACL_INT8_PRELU_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int DoLeakReluInt8(const int8_t *inputs, int8_t *output_ptr, const LeakyReluQuantArg *quant_prelu_parm, int task_id);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_PRELU_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.c
new file mode 100644
index 00000000..cfe16eb3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.c
@@ -0,0 +1,839 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
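+ */
+
+/* With zero points za (input) and zw (weight), the int8 matmul expands as
+ *   sum_d (a_d - za) * (w_d - zw)
+ *     = sum_d a_d * w_d - zw * sum_d a_d - za * sum_d w_d + depth * za * zw.
+ * The kernels below accumulate the raw product sum_d a_d * w_d, subtract the
+ * precomputed per-row term input_sum (zw times the row sum of a), and add a
+ * bias folded by CalcWeightBiasSums() (bias + depth * za * zw - za * column
+ * sum of w) before requantizing with MultiplyByQuantizedMultiplier(). */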
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/int8/fixed_point.h"
+
+void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
+  int col16 = UP_ROUND(col, C16NUM);
+  for (int r = 0; r < row; r++) {
+    int rd2 = r / C2NUM;
+    int rm2 = r % C2NUM;
+    for (int c = 0; c < col; c++) {
+      int cd16 = c / C16NUM;
+      int cm16 = c % C16NUM;
+      int dst_index = rd2 * col16 * C2NUM + cd16 * C2NUM * C16NUM + rm2 * C16NUM + cm16;
+      int src_index = r * col + c;
+      dst_ptr[dst_index] = src_ptr[src_index];
+    }
+  }
+}
+
+void RowMajor2Row4x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) {
+  int row_div = row / C4NUM * C4NUM;
+  int col_4 = UP_ROUND(col, C4NUM);
+  int col_div = col / C4NUM * C4NUM;
+
+  const int8_t *src_r4 = src;
+  int8_t *packed_r4 = dst;
+  const int8_t *src_c4 = NULL;
+  int8_t *packed_c4 = NULL;
+  for (int r = 0; r < row_div; r += C4NUM) {
+    src_c4 = src_r4;
+    packed_c4 = packed_r4;
+
+    for (int c = 0; c < col_div; c += C4NUM) {
+      for (int i = 0; i < C4NUM; i++) {
+        packed_c4[i * C4NUM + 0] = src_c4[i * col + 0];
+        packed_c4[i * C4NUM + 1] = src_c4[i * col + 1];
+        packed_c4[i * C4NUM + 2] = src_c4[i * col + 2];
+        packed_c4[i * C4NUM + 3] = src_c4[i * col + 3];
+      }
+      src_c4 += C4NUM;
+      packed_c4 += C16NUM;
+    }
+
+    if (col != col_div) {  // pack the column tail; the row pointers must still advance below
+      memset(packed_c4, 0, C16NUM * sizeof(int8_t));
+      for (int i = 0; i < C4NUM; ++i) {
+        for (int c = 0; c < col - col_div; ++c) {
+          packed_c4[i * C4NUM + c] = src_c4[i * col + c];
+        }
+      }
+    }
+    src_r4 += C4NUM * col;
+    packed_r4 += C4NUM * col_4;
+  }
+
+  if (row == row_div) {
+    return;
+  }
+  memset(packed_r4, 0, C4NUM * col_4);
+  src_c4 = src_r4;
+  packed_c4 = packed_r4;
+  for (int c = 0; c < col_div; c += C4NUM) {
+    for (int i = 0; i < row - row_div; ++i) {
+      packed_c4[i * C4NUM + 0] = src_c4[i * col + 0];
+      packed_c4[i * C4NUM + 1] = src_c4[i * col + 1];
+      packed_c4[i * C4NUM + 2] = src_c4[i * col + 2];
+      packed_c4[i * C4NUM + 3] = src_c4[i * col + 3];
+    }
+    src_c4 += C4NUM;
+    packed_c4 += C16NUM;
+  }
+  if (col == col_div) {
+    return;
+  }
+  for (int i = 0; i < row - row_div; ++i) {
+    for (int c = 0; c < col - col_div; ++c) {
+      packed_c4[i * C4NUM + c] = src_c4[i * col + c];
+    }
+  }
+}
+
+void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
+  int row16 = UP_ROUND(row, C16NUM);
+  int stride = C16NUM * C2NUM;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int stride_idx = c / C2NUM * (row16 / C16NUM) + r / C16NUM;
+      int dst_idx = stride * stride_idx + c % C2NUM * C16NUM + r % C16NUM;
+      int src_idx = r * col + c;
+      dst_ptr[dst_idx] = src_ptr[src_idx];
+    }
+  }
+}
+
+void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
+  int col4 = UP_ROUND(col, C4NUM);
+  for (int r = 0; r < row; r++) {
+    int rd8 = r / C8NUM;
+    int rm8 = r % C8NUM;
+    for (int c = 0; c < col; c++) {
+      int cd4 = c / C4NUM;
+      int cm4 = c % C4NUM;
+      int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4;
+      int src_index = r * col + c;
+      dst_ptr[dst_index] = src_ptr[src_index];
+    }
+  }
+}
+
+void MatrixPack4x16UnitInt8(const int8_t *src, int8_t *dst, int row, int col, int stride) {
+  for (int r = 0; r < row; r++) {
+    const int8_t *src_r = src + r * stride;
+    int8_t *dst_r = dst + r * C16NUM;
+    memcpy(dst_r, src_r, col * sizeof(int8_t));
+  }
+  return;
+}
+
+void MatrixEmptyInt8(int8_t *dst, int row, int col) {
+  for (int r = 0; r < row; r++) {
+    int8_t *dst_r = dst + r * C16NUM;
+    memset(dst_r, 0, col * 
sizeof(int8_t)); + } + return; +} + +void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd16 = r / C16NUM; + int rm16 = r % C16NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + /* Row-major to row16x4-major (block row-major) */ + int col16 = UP_ROUND(col, C16NUM); + int row_4div = row / C4NUM * C4NUM; + int row_4res = row - row_4div; + int col_16div = col / C16NUM * C16NUM; + int col_16res = col - col_16div; + int8_t *src_r = (int8_t *)src_ptr; + int8_t *dst_r = (int8_t *)dst_ptr; + + for (int ri = 0; ri < row_4div; ri += C4NUM) { + for (int ci = 0; ci < col_16div; ci += C16NUM) { + size_t col_offset = (size_t)col; + int8_t *src_c = src_r + ci; + int8_t *dst_c = dst_r + ci * C4NUM; +#ifdef ENABLE_ARM64 + asm volatile( + "mov x10, %[src_c] \n" + "mov x11, %[dst_c] \n" + + "ld1 {v0.16b}, [x10], %[col_offset]\n" + "ld1 {v1.16b}, [x10], %[col_offset]\n" + "ld1 {v2.16b}, [x10], %[col_offset]\n" + "ld1 {v3.16b}, [x10], %[col_offset]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "st1 {v2.16b}, [x11], #16\n" + "st1 {v3.16b}, [x11], #16\n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [col_offset] "r"(col_offset) + : "x10", "x11", "v0", "v1", "v2", "v3"); +#elif ENABLE_ARM32 + asm volatile( + "mov r0, %[src_c] \n" + "mov r1, %[dst_c] \n" + "mov r2, %[col_offset] \n" + "mov r3, #16 \n" + + "vld1.8 {q0}, [r0], r2 \n" + "vld1.8 {q1}, [r0], r2 \n" + "vld1.8 {q2}, [r0], r2 \n" + "vld1.8 {q3}, [r0], r2 \n" + + "vst1.32 {d0, d1}, [r1], r3 \n" + "vst1.32 {d2, d3}, [r1], r3 \n" + "vst1.32 {d4, d5}, [r1], r3 \n" + "vst1.32 {d6, d7}, [r1], r3 \n" + + : + : [dst_c] "r"(dst_c), [src_c] "r"(src_c), [col_offset] "r"(col_offset) + : "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3"); +#else + MatrixPack4x16UnitInt8(src_c, dst_c, C4NUM, C16NUM, col_offset); +#endif + } + + if (col != col_16div) { + MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col); + MatrixEmptyInt8(dst_r + col_16div * C4NUM + col_16res, C4NUM, C16NUM - col_16res); + } + src_r += C4NUM * col; + dst_r += C4NUM * col16; + } + + if (row != row_4div) { + memset(dst_r, 0, C4NUM * col16); + + for (int ci = 0; ci < col_16div; ci += C16NUM) { + MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col); + } + + if (col != col_16div) { + MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, row_4res, col_16res, col); + } + } + return; +} + +void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, + const int32_t *input_sum, const int32_t *bias) { + /* row4x16-major * row16x4-major => row4x4-major */ + for (int r = 0; r < row_4; r++) { + for (int c = 0; c < col_4; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int64_t ci = c4div * row_4 * C4NUM + r * C4NUM + c4mod; + int32_t value = 0; + for (int d = 0; d < deep_16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + int64_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + int64_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value 
= value + a[ai] * b[bi]; + } + value -= input_sum[r]; + value += bias[c]; + ((int32_t *)dst)[ci] = value; + } + } + return; +} + +void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, bool peroc) { + /* support per-layer && weight per-channel */ + /* row4x16-major * row16x2-major => (int8)row-major*/ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r4div = r / C4NUM, r4mod = r % C4NUM; + size_t c2div = c / C2NUM, c2mod = c % C2NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_16; d++) { + size_t d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c2div * deep_16 * C2NUM + d16div * C2NUM * C16NUM + c2mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = + peroc ? input_sum[c2div * UP_ROUND(row, C4NUM) * C2NUM + r * C2NUM + c2mod] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = peroc ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +#ifndef ENABLE_ARM +void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int32_t *a_sums, + const int32_t *bias, int mini, int maxi, int out_zp, const int32_t *multiplier, + const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, + const int32_t *filter_zp) { + /* + * row4x16-major * row16x4-major => (int8)row-major + * support per-layer && weight per-channel + * a_sums is perT : input_row_sum * filter_zp + * perOc : input_row_sum + * */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int64_t ci = r * stride + c; + int32_t value = 0; + for (int d = 0; d < deep16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + int64_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + int64_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = filter_peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = filter_peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = filter_peroc ? 
multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + out_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} +#endif +void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel) { + /* row8x4-major * row4x8-major => (int8)row-major */ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r8div = r / C8NUM, r8mod = r % C8NUM; + size_t c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_4; d++) { + size_t d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + r8mod * C4NUM + d4mod; + size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = + per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) * C8NUM + r * C8NUM + c8mod] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, const int32_t *filter_zp) { + /* row4x4-major * row4x16-major => (int8)row-major */ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r4div = r / C4NUM, r4mod = r % C4NUM; + size_t c16div = c / C16NUM, c16mod = c % C16NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_4; d++) { + size_t d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep_4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep_4 * C16NUM + d4div * C16NUM * C4NUM + c16mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = per_channel ? input_sum[r] * filter_zp[c] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? 
multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *pack_ic, int32_t *input_sum_r, size_t src_stride, + size_t ic_4div, size_t ic_4res, int32_t filter_zp) { + asm volatile( + "dup v2.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v3.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x15, #0 \n" + "1: \n" + "cmp x15, %[ic_4div] \n" + "add x15, x15, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 1b \n" + + "3: \n" /* ic res 1 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "4: \n" /* ic res 2 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "5: \n" /* ic res 3 */ + "dup v0.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "6: \n" + "mul v2.4s, v2.4s, v3.4s \n" + + "st1 {v2.4s}, [x14], #16 \n" + + : + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [input_sum_r] "r"(input_sum_r), [src_stride] "r"(src_stride), + [ic_4div] "r"(ic_4div), [ic_4res] "r"(ic_4res), [filter_zp] "r"(filter_zp) + : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); + return; +} +#endif +void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, + size_t input_channel, size_t plane_size, int32_t filter_zp) { + size_t ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t hw_4div = plane_size / C4NUM * C4NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (size_t hwi = 0; hwi < hw_4div; hwi += C4NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + 
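// NEON fast path: stores one 4-plane x 4-channel tile as a contiguous
+    // 16-byte block and accumulates the four per-plane byte sums (saddlp
+    // pairs), scaled by filter_zp, into input_sum_r.
+ 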
PackInput4x4AndInputSumPert_arm64(src_ic, pack_ic, input_sum_r, src_stride, ic_4div, ic_4res, filter_zp); +#else + int32_t tmp_sum_value[4] = {0}; + for (size_t ici = 0; ici < ic_4div; ici += C4NUM) { + for (size_t i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (size_t ici = ic_4div; ici < input_channel; ici += 1) { + for (size_t i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (size_t ici = input_channel; ici < ic4; ici += 1) { + for (size_t i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (size_t i = 0; i < C4NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C4NUM; + pack_r += ic4 * C4NUM; + } + + if (hw_4div != plane_size) { + (void)memset(pack_r, 0, C4NUM * ic4); + for (size_t hwi = hw_4div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (size_t ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (size_t ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (size_t hwi = plane_size; hwi < hw4; hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput2Col4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *packed_ic, int32_t *input_sum, int row, + size_t row_stride, int32_t filter_zp) { + asm volatile( + "ld1 {v12.s}[0], [%[input_sum]]\n" + "mov w10, %w[row]\n" + "mov x11, %[src_ic]\n" + "mov x12, %[packed_ic]\n" + "sxtl v6.8h, v12.8b\n" + "sxtl v12.4s, v6.4h\n" + "cmp w10, wzr\n" + "beq 1f\n" + "2:\n" + "subs w10, w10, #4\n" + "ld1 {v0.s}[0], [x11], %[row_stride]\n" + "ld1 {v1.s}[0], [x11], %[row_stride]\n" + "ld1 {v0.s}[1], [x11], %[row_stride]\n" + "ld1 {v1.s}[1], [x11], %[row_stride]\n" + "zip1 v2.8b, v0.8b, v1.8b\n" + "zip2 v3.8b, v0.8b, v1.8b\n" + "zip1 v4.4h, v2.4h, v3.4h\n" + "zip2 v5.4h, v2.4h, v3.4h\n" + "st1 {v4.4h, v5.4h}, [x12], #16\n" + + "sxtl v6.8h, v0.8b\n" + "sxtl v7.4s, v6.4h\n" + "sxtl2 v8.4s, v6.8h\n" + "sxtl v9.8h, v1.8b\n" + "sxtl v10.4s, v9.4h\n" + "sxtl2 v11.4s, v9.8h\n" + "add v10.4s, v10.4s, v7.4s\n" + "add v10.4s, v10.4s, v8.4s\n" + "add v10.4s, v10.4s, v10.4s\n" + "add v10.4s, v10.4s, v11.4s\n" + "bgt 2b\n" + "1:\n" + + : + : [src_ic] "r"(src_ic), [packed_ic] "r"(packed_ic), [input_sum] "r"(input_sum), [row] "r"(row), + [row_stride] "r"(row_stride), [filter_zp] "r"(filter_zp) + : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12"); + + return; +} +#endif + +// For 
matmul input a transpose case +void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row, + int col, int row_stride, int32_t filter_zp) { + const int row_tile = C4NUM; + int row_align = UP_ROUND(row, row_tile); + int row_div = row / row_tile * row_tile; + const int row_res = row - row_div; + + const int col_tile = C4NUM; + int col_div = col / col_tile * col_tile; + const int col_res = col - col_div; + + const int8_t *src_ic = NULL; + int8_t *packed_ic = NULL; + int32_t *tmp_sum = NULL; + for (int c = 0; c < col_div; c += C4NUM) { + int r = 0; + src_ic = src_input + c; + packed_ic = packed_input + c * row_align; + tmp_sum = input_sum + c; +#ifdef ENABLE_ARM64 + PackInput2Col4x4AndInputSumPert_arm64(src_ic, packed_ic, tmp_sum, row_div, row_stride, filter_zp); + packed_ic += C4NUM * row_div; + src_ic += row_div * row_stride; +#else + for (; r < row_div; r += C4NUM) { + for (int i = 0; i < row_tile; i++) { + packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0]; + packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1]; + packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2]; + packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3]; + + tmp_sum[0] += src_ic[i * row_stride + 0]; + tmp_sum[1] += src_ic[i * row_stride + 1]; + tmp_sum[2] += src_ic[i * row_stride + 2]; + tmp_sum[3] += src_ic[i * row_stride + 3]; + } + packed_ic += C16NUM; + src_ic += row_tile * row_stride; + } +#endif + + for (r = 0; r < row_res; ++r) { + for (int i = 0; i < C4NUM; ++i) { + packed_ic[i * row_tile + r] = src_ic[r * row_stride + i]; + tmp_sum[i] += src_ic[r * row_stride + i]; + } + } + } + if (col_res == 0) { + for (int i = 0; i < col; ++i) { + input_sum[i] *= filter_zp; + } + return; + } + src_ic = src_input + col_div; + packed_ic = packed_input + row_align * col_div; + tmp_sum = input_sum + col_div; + for (int r = 0; r < row_div; r += row_tile) { + for (int i = 0; i < col_res; ++i) { + packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i]; + packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i]; + packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i]; + packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i]; + + tmp_sum[i] += src_ic[0 * row_stride + i]; + tmp_sum[i] += src_ic[1 * row_stride + i]; + tmp_sum[i] += src_ic[2 * row_stride + i]; + tmp_sum[i] += src_ic[3 * row_stride + i]; + } + src_ic += row_tile * row_stride; + packed_ic += row_tile * col_tile; + } + + for (int r = 0; r < row_res; ++r) { + for (int c = 0; c < col_res; ++c) { + packed_ic[c * row_tile + r] = src_ic[r * row_stride + c]; + tmp_sum[c] += src_ic[r * row_stride + c]; + } + } + + for (int i = 0; i < col; ++i) { + input_sum[i] *= filter_zp; + } +} + +void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_16 = UP_ROUND(row, C16NUM); + int stride = sizeof(int8_t) * 16 * 4; + for (int r = 0; r < row_16; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / 4 * (row_16 / 16) + r / 16; + if (r >= row) { + dst[stride * stride_idx + c % 4 * 16 + r % 16] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % 4 * 16 + r % 16] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C4NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C4NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = 0; 
+ } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C16NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < cur_oc; ++c) { + int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C16NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int32_t *dst, DataOrder order) { + for (int r = 0; r < row; ++r) { + int sum = 0; + for (int c = 0; c < col; ++c) { + if (order == RowMajor) { + sum += input[r * col + c]; + } else { + sum += input[c * row + r]; + } + } + sum *= weight_zp; + dst[r] = sum; + } +} + +// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums +void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int32_t *weight_zp_ptr, + const int32_t *bias, int32_t *dst, DataOrder order, bool filter_per_channel) { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + if (order == RowMajor) { + sum += weight[r * col + c]; + } else { + sum += weight[c * row + r]; + } + } + int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0]; + dst[c] = row * input_zp * weight_zp - input_zp * sum; + if (bias != NULL) { + dst[c] += bias[c]; + } + } +} + +void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp, + const int32_t *weight_zp_ptr, const int32_t *bias, int32_t *dst, DataOrder order, + bool filter_per_channel) { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + if (order == RowMajor) { + sum += weight[r * stride + c]; + } else { + sum += weight[c * row + r]; + } + } + int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0]; + dst[c] = row * input_zp * weight_zp - input_zp * sum; + if (bias != NULL) { + dst[c] += bias[c]; + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.h new file mode 100644 index 00000000..3dec3b18 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/matmul_int8.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_MATMUL_H_
+#define NNACL_INT8_MATMUL_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/matmul_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/* 4x16 16x4 -> 4x4 */
+/* sdot 4x4 4x16 -> 4x16 */
+/* matmul */
+void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16,
+                     const int32_t *input_sum, const int32_t *bias);
+void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col);
+void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col);
+void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc);
+void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row,
+                                     int col, int row_stride, int32_t filter_zp);
+void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int32_t *dst, DataOrder order);
+void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int32_t *weight_zp_ptr,
+                        const int32_t *bias, int32_t *dst, DataOrder order, bool filter_per_channel);
+void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp,
+                            const int32_t *weight_zp_ptr, const int32_t *bias, int32_t *dst, DataOrder order,
+                            bool filter_per_channel);
+void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int32_t *a_sums,
+                   const int32_t *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
+                   const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
+                   const int32_t *filter_zp);
+/* 8x4 4x8 -> 8x8 */
+/* optimize conv */
+void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
+                      size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift,
+                      const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini,
+                      int32_t maxi, size_t per_channel);
+
+/* 4x16 16x2 -> 4x2 */
+/* arm32 conv1x1 */
+void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
+                      size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift,
+                      const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini,
+                      int32_t maxi, bool peroc);
+
+/* 4x4 4x16 -> 4x16 */
+/* optimize conv1x1 */
+void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum,
+                                 size_t input_channel, size_t plane_size, int32_t filter_zp);
+void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
+                       size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift,
+                       const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini,
+                       int32_t maxi, size_t per_channel, const int32_t *filter_zp);
+
+#ifdef 
ENABLE_ARM64 +void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, + const int32_t *a_sums, const int32_t *bias, int act_min, int act_max, int out_zp, + int32_t *multiplier, int32_t *left_shift, int32_t *right_shift, int row, int col, int stride, + int filter_peroc); + +void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, + const int32_t *input_sum, const int32_t *bias); +#endif +#ifdef ENABLE_ARM32 +void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, + const int32_t *input_sums, const int32_t *weight_bias, int act_min, int act_max, int out_zp, + int32_t *multiplier, int32_t *left_shift, int32_t *right_shift, int stride, int per_channel); +#endif +#ifdef __cplusplus +} +#endif + +#endif // LITE_SRC_BACKEND_ARM_NNACL_INT8_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.c new file mode 100644 index 00000000..d7e55934 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.c @@ -0,0 +1,238 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/mul_int8.h" + +#ifdef ENABLE_NEON +int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) { + int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1); + int32x4_t raw_sum = vqrdmulhq_s32(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(raw_sum, right_shift_out_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(raw_sum, fixup); + raw_sum = vrshlq_s32(fixed_up_x, right_shift_out_vec); + return vqmovn_s32(raw_sum); +} + +void MulInt8NEON(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const MulQuantArg *quant_arg, int32_t *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(quant_arg->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)quant_arg->shift_left_); + int32x4_t right_shift_out_vec = vdupq_n_s32(-quant_arg->shift_right_); + int16x8_t out_zp_vec = vdupq_n_s16(quant_arg->out_quant_arg_.zp_); + int8x16_t out_min_vec = vdupq_n_s8(quant_arg->output_activation_min_); + int8x16_t out_max_vec = vdupq_n_s8(quant_arg->output_activation_max_); + int8x8_t out_min_vec_s8 = vdup_n_s8(quant_arg->output_activation_min_); + int8x8_t out_max_vec_s8 = vdup_n_s8(quant_arg->output_activation_max_); + + for (; (*index) <= real_dst_count - 16; (*index) += 16) { + int16x8_t zp1_vec = vdupq_n_s16(quant_arg->in_quant_args_[0].zp_); + int16x8_t zp2_vec = vdupq_n_s16(quant_arg->in_quant_args_[1].zp_); + int8x16_t input0_vec = vld1q_s8(input0_data + *index); + int8x16_t input1_vec = vld1q_s8(input1_data + *index); + int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); + int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); + int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); + int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); + input0_low = vaddq_s16(input0_low, zp1_vec); + input0_high = vaddq_s16(input0_high, zp1_vec); + input1_low = vaddq_s16(input1_low, zp2_vec); + input1_high = vaddq_s16(input1_high, zp2_vec); + + int16x4_t input0_low_low = vget_low_s16(input0_low); + int16x4_t input0_low_high = vget_high_s16(input0_low); + int16x4_t input0_high_low = vget_low_s16(input0_high); + int16x4_t input0_high_high = vget_high_s16(input0_high); + int16x4_t input1_low_low = vget_low_s16(input1_low); + int16x4_t input1_low_high = vget_high_s16(input1_low); + int16x4_t input1_high_low = vget_low_s16(input1_high); + int16x4_t input1_high_high = vget_high_s16(input1_high); + + int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, right_shift_out_vec, + output_multiplier_vec); + int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = 
vminq_s8(res_s8, out_max_vec);
+    res_s8 = vmaxq_s8(res_s8, out_min_vec);
+    vst1q_s8(output_data, res_s8);
+    output_data += 16;
+  }
+  for (; (*index) <= real_dst_count - 8; (*index) += 8) {
+    int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, quant_arg->in_quant_args_[0].zp_);
+    int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, quant_arg->in_quant_args_[1].zp_);
+
+    int16x4_t input0_low = vget_low_s16(input0_val);
+    int16x4_t input0_high = vget_high_s16(input0_val);
+    int16x4_t input1_low = vget_low_s16(input1_val);
+    int16x4_t input1_high = vget_high_s16(input1_val);
+
+    int16x4_t sum_low =
+      ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
+    int16x4_t sum_high =
+      ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
+
+    int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec);
+    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
+    res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8);
+    res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8);
+    vst1_s8(output_data, res_u8_n0);
+    output_data += 8;
+  }
+}
+#endif
+
+void FastMul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int depth,
+             int64_t real_dst_count, bool input1_broad, const MulQuantArg *quant_arg) {
+  // input0 is the broadcast operand by default; when input1_broad is set, input1 is the
+  // broadcast operand instead, so the two input zero points are swapped.
+  int32_t zp1 = quant_arg->in_quant_args_[0].zp_;
+  int32_t zp2 = quant_arg->in_quant_args_[1].zp_;
+  if (input1_broad) {
+    zp1 = quant_arg->in_quant_args_[1].zp_;
+    zp2 = quant_arg->in_quant_args_[0].zp_;
+  }
+#ifdef ENABLE_NEON
+  int32x4_t output_multiplier_vec = vdupq_n_s32(quant_arg->output_multiplier_);
+  int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)quant_arg->shift_left_);
+  int32x4_t right_shift_out_vec = vdupq_n_s32(-quant_arg->shift_right_);
+  int16x8_t out_zp_vec = vdupq_n_s16(quant_arg->out_quant_arg_.zp_);
+  int8x16_t out_min_vec = vdupq_n_s8(quant_arg->output_activation_min_);
+  int8x16_t out_max_vec = vdupq_n_s8(quant_arg->output_activation_max_);
+  int8x8_t out_min_vec_s8 = vdup_n_s8(quant_arg->output_activation_min_);
+  int8x8_t out_max_vec_s8 = vdup_n_s8(quant_arg->output_activation_max_);
+  int16x8_t zp1_vec = vdupq_n_s16(zp1);
+  int16x8_t zp2_vec = vdupq_n_s16(zp2);
+#endif
+  for (int index = 0; index < real_dst_count; ++index) {
+    int j = 0;
+#ifdef ENABLE_NEON
+    for (; j <= depth - 16; j += 16) {
+      int8x16_t input0_vec = vld1q_s8(input0_data + j);
+      int8x16_t input1_vec = vld1q_s8(input1_data);
+      int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec));
+      int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec));
+      int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec));
+      int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec));
+      input0_low = vaddq_s16(input0_low, zp1_vec);
+      input0_high = vaddq_s16(input0_high, zp1_vec);
+      input1_low = vaddq_s16(input1_low, zp2_vec);
+      input1_high = vaddq_s16(input1_high, zp2_vec);
+
+      int16x4_t input0_low_low = vget_low_s16(input0_low);
+      int16x4_t input0_low_high = vget_high_s16(input0_low);
+      int16x4_t input0_high_low = vget_low_s16(input0_high);
+      int16x4_t input0_high_high = vget_high_s16(input0_high);
+      int16x4_t input1_low_low = vget_low_s16(input1_low);
+      int16x4_t input1_low_high = vget_high_s16(input1_low);
+      int16x4_t input1_high_low = vget_low_s16(input1_high);
+      int16x4_t input1_high_high = vget_high_s16(input1_high);
+
+      int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec,
+                                                 right_shift_out_vec, output_multiplier_vec);
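+      // ClacSumHalfWordMul requantizes each 4-lane half: it widens the zero-point-adjusted
+      // int16 inputs to a 32-bit product, applies the fixed-point output multiplier together
+      // with the left/right output shifts, and narrows the result back to int16.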
int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = vminq_s8(res_s8, out_max_vec); + res_s8 = vmaxq_s8(res_s8, out_min_vec); + vst1q_s8(output_data, res_s8); + input1_data += 16; + output_data += 16; + } + for (; j <= depth - 8; j += 8) { + int8x8_t input0_vec = vld1_s8(input0_data + j); + int8x8_t input1_vec = vld1_s8(input1_data); + int16x8_t input0_val = vmovl_s8(input0_vec); + int16x8_t input1_val = vmovl_s8(input1_vec); + input0_val = vaddq_s16(input0_val, zp1_vec); + input1_val = vaddq_s16(input1_val, zp2_vec); + + int16x4_t input0_low = vget_low_s16(input0_val); + int16x4_t input0_high = vget_high_s16(input0_val); + int16x4_t input1_low = vget_low_s16(input1_val); + int16x4_t input1_high = vget_high_s16(input1_val); + + int16x4_t sum_low = + ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high = + ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); + res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); + vst1_s8(output_data, res_u8_n0); + input1_data += 8; + output_data += 8; + } +#endif + for (; j < depth; ++j) { + const int32_t input0_val = zp1 + input0_data[j]; + const int32_t input1_val = zp2 + input1_data[0]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << (size_t)quant_arg->shift_left_), + quant_arg->output_multiplier_), + quant_arg->shift_right_); + + mul_result += quant_arg->out_quant_arg_.zp_; + mul_result = mul_result < quant_arg->output_activation_max_ ? mul_result : quant_arg->output_activation_max_; + mul_result = mul_result > quant_arg->output_activation_min_ ? mul_result : quant_arg->output_activation_min_; + output_data[0] = (int8_t)mul_result; + input1_data++; + output_data++; + } + } + return; +} + +void Mul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const MulQuantArg *quant_arg) { + int index = 0; +#ifdef ENABLE_NEON + MulInt8NEON(input0_data, input1_data, output_data, real_dst_count, quant_arg, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = quant_arg->in_quant_args_[0].zp_ + input0_data[index]; + const int32_t input1_val = quant_arg->in_quant_args_[1].zp_ + input1_data[index]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << (size_t)quant_arg->shift_left_), + quant_arg->output_multiplier_), + quant_arg->shift_right_); + + mul_result += quant_arg->out_quant_arg_.zp_; + mul_result = mul_result < quant_arg->output_activation_max_ ? 
mul_result : quant_arg->output_activation_max_;
+    mul_result = mul_result > quant_arg->output_activation_min_ ? mul_result : quant_arg->output_activation_min_;
+    output_data[index] = (int8_t)mul_result;
+  }
+  return;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.h
new file mode 100644
index 00000000..ef88bba1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/mul_int8.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_MUL_INT8_H_
+#define NNACL_INT8_MUL_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/mul_parameter.h"
+#include "nnacl_c/int8/common_func_int8.h"
+#include "nnacl_c/int8/fixed_point.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Mul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count,
+         const MulQuantArg *quant_arg);
+void FastMul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int depth,
+             int64_t real_dst_count, bool input1_broad, const MulQuantArg *quant_arg);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_MUL_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c
new file mode 100644
index 00000000..25f0ecea
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.c
@@ -0,0 +1,452 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/int8/pack_int8.h" + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param) { + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic8_round = UP_ROUND(in_channel, C8NUM); + int ic8 = in_channel / C8NUM * C8NUM; + int in_plane = in_h * in_w; + + for (int b = 0; b < in_batch; b++) { + int src_batch_offset = b * in_channel * in_plane; + int dst_batch_offset = b * ic8_round * in_plane; + for (int k = 0; k < in_plane; k++) { + int src_plane_offset = src_batch_offset + k * in_channel; + int dst_plane_offset = dst_batch_offset + k * C8NUM; + for (int i = 0; i < ic8; i += 8) { + int src_c_offset = src_plane_offset + i; + int dst_c_offset = dst_plane_offset + i * in_plane; +#ifdef ENABLE_ARM + vst1q_s16(packed_input + dst_c_offset, vmovl_s8(vld1_s8(input_data + src_c_offset))); +#else + for (int j = 0; j < C8NUM; ++j) { + (packed_input + dst_c_offset)[j] = (int16_t)(input_data + src_c_offset)[j]; + } +#endif + } // ic8_minus loop + int res_c = in_channel - ic8; + int tmp_ic_offset = ic8 * in_plane; + for (int l = 0; l < res_c; ++l) { + int src_c_offset = src_plane_offset + ic8 + l; + int dst_c_offset = dst_plane_offset + tmp_ic_offset + l; + (packed_input + dst_c_offset)[0] = (int16_t)(input_data + src_c_offset)[0]; + } // res ic loop + int res2 = ic8_round - in_channel; + for (int l = 0; l < res2; ++l) { + int dst_c_offset = dst_plane_offset + tmp_ic_offset + res_c + l; + (packed_input + dst_c_offset)[0] = 0; + } // res ic loop + } // kh * kw loop + } +} + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, + const ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = input_channel / C8NUM * C8NUM; + int ic8_round = UP_ROUND(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int32_t zp; + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + zp = filter_zp[0].zp_; + } else { + zp = filter_zp[o].zp_; + } + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8_round * kernel_plane; + int i = 0; + for (; i < ic8; i += C8NUM) { + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + i * kernel_plane; +#ifdef ENABLE_ARM64 + int8x8_t src_s8 = vld1_s8(origin_weight_data + src_ic_offset); + int16x8_t src_s16 = vmovl_s8(src_s8); + int16x4_t src1_s16 = vget_low_s16(src_s16); + int16x4_t src2_s16 = vget_high_s16(src_s16); + int32x4_t src1_s32 = vmovl_s16(src1_s16); + int32x4_t src2_s32 = vmovl_s16(src2_s16); + int32x4_t zp_s32 = vdupq_n_s32(zp); + int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32); + int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32); + int16x4_t dst1_s16 = vqmovn_s32(dst1_s32); + int16x4_t dst2_s16 = vqmovn_s32(dst2_s32); + vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16); + vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16); +#else + for (int ci = 0; ci < C8NUM; ++ci) { + (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset 
+ ci)[0] - zp); + } +#endif + } + dst_oc_offset += ic8 * kernel_plane; + for (; i < input_channel; i++) { + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); + } + } + } +} + +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { + /* normal matmul : 4x16 * 16x4 -> 4x4 */ +#ifdef ENABLE_ARM + PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp); +#else + for (size_t r = 0; r < row4; r++) { + int32_t tmp_value = 0; + for (size_t c = 0; c < col16; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; + int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; + tmp_value += src[src_index]; + } + dst[r] = tmp_value * filter_zp; + } +#endif + return; +} +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { + int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int unit = conv_param->input_h_ * conv_param->input_w_; + + for (int b = 0; b < conv_param->input_batch_; b++) { + const int8_t *src_b = src + b * unit * conv_param->input_channel_; + int16_t *dst_b = dst + b * unit * ic4 * C4NUM; + for (int k = 0; k < unit; k++) { + const int8_t *src_k = src_b + k * conv_param->input_channel_; + int16_t *dst_k = dst_b + k * ic4 * C4NUM; + for (int c = 0; c < conv_param->input_channel_; c++) { + dst_k[c] = (int16_t)(src_k[c] - input_zp); + } + } + } +} + +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} + +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C4NUM * k + c4_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_channel = c4 * C4NUM; + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; 
i++) { + int8_t *dst_per_plane = (int8_t *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (int8_t *)src + batch_offset + i * channel, channel); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc4_batch_offset = b * nhwc4_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc4_batch_offset + i * c4 * C4NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel, + channel); + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc8_batch_offset = b * nhwc8_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; + ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0]; + ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1]; + ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2]; + ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3]; + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((int8_t *)dst + dst_res_c_offset)[0] = 
((int8_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index]; + } + } + } + return; +} + +void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) { + int hw8 = plane / C8NUM * C8NUM; + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const int8_t *src_batch = (const int8_t *)src + n * batch; + int8_t *dst_batch = (int8_t *)dst + n * batch; + int hw = 0; + for (; hw < hw8; hw += C8NUM) { + int c = 0; + for (; c < c8; c += C8NUM) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * plane + hw; +#ifdef ENABLE_ARM64 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v8.4h, v4.4h, v6.4h\n" + "trn2 v9.4h, v4.4h, v6.4h\n" + "trn1 v10.4h, v5.4h, v7.4h\n" + "trn2 v11.4h, v5.4h, v7.4h\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "trn1 v12.4h, v4.4h, v6.4h\n" + "trn2 v13.4h, v4.4h, v6.4h\n" + "trn1 v14.4h, v5.4h, v7.4h\n" + "trn2 v15.4h, v5.4h, v7.4h\n" + + "trn1 v0.2s, v8.2s, v12.2s\n" + "trn2 v4.2s, v8.2s, v12.2s\n" + "trn1 v1.2s, v10.2s, v14.2s\n" + "trn2 v5.2s, v10.2s, v14.2s\n" + "trn1 v2.2s, v9.2s, v13.2s\n" + "trn2 v6.2s, v9.2s, v13.2s\n" + "trn1 v3.2s, v11.2s, v15.2s\n" + "trn2 v7.2s, v11.2s, v15.2s\n" + + "st1 {v0.8b}, [x11], %[dstStride]\n" + "st1 {v1.8b}, [x11], %[dstStride]\n" + "st1 {v2.8b}, [x11], %[dstStride]\n" + "st1 {v3.8b}, [x11], %[dstStride]\n" + "st1 {v4.8b}, [x11], %[dstStride]\n" + "st1 {v5.8b}, [x11], %[dstStride]\n" + "st1 {v6.8b}, [x11], %[dstStride]\n" + "st1 {v7.8b}, [x11], %[dstStride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31"); +#elif ENABLE_ARM32 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.8 {d0}, [r10], %[srcStride]\n" + "vld1.8 {d1}, [r10], %[srcStride]\n" + "vld1.8 {d2}, [r10], %[srcStride]\n" + "vld1.8 {d3}, [r10], %[srcStride]\n" + "vld1.8 {d4}, [r10], %[srcStride]\n" + "vld1.8 {d5}, [r10], %[srcStride]\n" + "vld1.8 {d6}, [r10], %[srcStride]\n" + "vld1.8 {d7}, [r10], %[srcStride]\n" + + "vtrn.8 d0, d1\n" + "vtrn.8 d2, d3\n" + "vtrn.8 d4, d5\n" + "vtrn.8 d6, d7\n" + + "vtrn.16 d0, d2\n" + "vtrn.16 
d1, d3\n" + "vtrn.16 d4, d6\n" + "vtrn.16 d5, d7\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vst1.8 {d0}, [r12], %[dstStride]\n" + "vst1.8 {d1}, [r12], %[dstStride]\n" + "vst1.8 {d2}, [r12], %[dstStride]\n" + "vst1.8 {d3}, [r12], %[dstStride]\n" + "vst1.8 {d4}, [r12], %[dstStride]\n" + "vst1.8 {d5}, [r12], %[dstStride]\n" + "vst1.8 {d6}, [r12], %[dstStride]\n" + "vst1.8 {d7}, [r12], %[dstStride]\n" + : + : [dst_ptr] "r"(dst_ptr), [src_ptr] "r"(src_ptr), [srcStride] "r"(srcStride), [dstStride] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < C8NUM; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < plane; hw++) { + const int8_t *src_ptr = src_batch + hw * channel; + int8_t *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.h new file mode 100644 index 00000000..b89b74c9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pack_int8.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_PACK_INT8_H_
+#define NNACL_INT8_PACK_INT8_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/conv_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
+void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
+void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel);
+void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel);
+
+void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
+void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param);
+void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, const ConvParameter *conv_param);
+#ifdef ENABLE_ARM
+void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
+void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, const int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
+                         size_t oc_res, size_t stride);
+#endif
+
+void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param);
+void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
+                             const ConvQuantArg *quant_qrg);
+void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
+                                   const ConvQuantArg *quant_qrg);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_PACK_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.c
new file mode 100644
index 00000000..cf042b6d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.c
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/int8/pad_int8.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/errorcode.h" + +int PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, + const int32_t *paddings, const int tid, const int thread_num) { + if (thread_num == 0) { + return NNACL_ERR; + } + int32_t copy_size = in_dims[3]; + for (int n = 0; n < in_dims[0]; n++) { + for (int h = tid; h < in_dims[1]; h += thread_num) { + for (int w = 0; w < in_dims[2]; w++) { + const int8_t *in = in_data + Offset(in_dims, n, h, w, 0); + int8_t *out = out_data + Offset(out_dims, n + paddings[0], h + paddings[2], w + paddings[4], paddings[6]); + memcpy(out, in, (size_t)copy_size * sizeof(int8_t)); + } + } + } + return NNACL_OK; +} + +int TransOut2InputDimIndexInt8(int out_dim_index, int left_pad, int in_dim, int offset) { + if (out_dim_index < left_pad) { + // left pad + const int index_sum = left_pad + offset - 1; + return MSMAX(index_sum - out_dim_index, offset); + } + out_dim_index -= left_pad; + if (out_dim_index < in_dim) { + return out_dim_index; + } + // right pad + out_dim_index -= in_dim; + const int index_sum = in_dim - 1 - offset; + return MSMAX(index_sum - out_dim_index, 0); +} + +int GetInputFlattenIndexInt8(int out_flatten_index, const int32_t *input_shape, int mirror_offset, + const int *in_strides, const int *out_strides, const int *paddings) { + int in_flatten_index = 0; + int i; + for (i = 0; i < COMM_SHAPE_SIZE; ++i) { + int left_pad = paddings[i * 2]; + NNACL_CHECK_ZERO_RETURN_ERR(out_strides[i]); + int out_dim_index = out_flatten_index / out_strides[i]; + out_flatten_index %= out_strides[i]; + int in_dim_index = TransOut2InputDimIndexInt8(out_dim_index, left_pad, input_shape[i], mirror_offset); + in_flatten_index += in_dim_index * in_strides[i]; + } + return in_flatten_index; +} + +void MirrorPadInt8(const int8_t *in, int8_t *out, const int32_t *input_shape, int mirror_offset, const int *in_strides, + const int *out_strides, const int *paddings, int begin, int end) { + for (int i = begin; i < end; ++i) { + out[i] = in[GetInputFlattenIndexInt8(i, input_shape, mirror_offset, in_strides, out_strides, paddings)]; + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.h new file mode 100644 index 00000000..08fe53c0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pad_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_PAD_INT8_H_
+#define NNACL_INT8_PAD_INT8_H_
+
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/pad_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims,
+                  const int32_t *paddings, const int tid, const int thread_num);
+void MirrorPadInt8(const int8_t *in, int8_t *out, const int32_t *input_shape, int mirror_offset, const int *in_strides,
+                   const int *out_strides, const int *paddings, int begin, int end);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_PAD_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c
new file mode 100644
index 00000000..3f4de98d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.c
@@ -0,0 +1,516 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/pooling_int8.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/errorcode.h"
+
+int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param,
+                   PoolingComputeParam *compute_args, QuantArg **quant_args) {
+  int stride_w = pooling_param->stride_w_;
+  int stride_h = pooling_param->stride_h_;
+  int pad_w = pooling_param->pad_l_;
+  int pad_h = pooling_param->pad_u_;
+  int win_w = pooling_param->window_w_;
+  int win_h = pooling_param->window_h_;
+  int channel = compute_args->input_channel_;
+  int in_w = compute_args->input_w_;
+  int in_h = compute_args->input_h_;
+  int output_w = compute_args->output_w_;
+  int output_h = compute_args->output_h_;
+  int output_batch = compute_args->output_batch_;
+  int out_plane = output_w * output_h;
+  float input_scale = quant_args[0][0].scale_;
+  int input_zp = quant_args[0][0].zp_;
+  float output_scale = quant_args[1][0].scale_;
+  int output_zp = quant_args[1][0].zp_;
+  double real_multiplier = input_scale / output_scale;
+  const int8_t out_min = INT8_MIN;
+  const int8_t out_max = INT8_MAX;
+
+  for (int batch = 0; batch < output_batch; batch++) {
+    int in_batch_offset = batch * in_h * in_w * channel;
+    int out_batch_offset = batch * output_h * output_w * channel;
+    for (int i = 0; i < out_plane; i++) {
+      int out_w_index = i % output_w;
+      int out_h_index = i / output_w;
+      int in_w_index = out_w_index * stride_w - pad_w;
+      int in_h_index = out_h_index * stride_h - pad_h;
+      int out_plane_offset = out_batch_offset + i * channel;
+      for (int j = 0; j < channel; j++) {
+        int in_channel_offset = in_batch_offset + j;
+        int out_channel_offset = out_plane_offset + j;
+        int16_t tmp_avg = 0;
+        int real_count = 0;
+        for (int h = 0; h < win_h; h++) {
+          for (int w = 0; w < win_w; w++) {
+            if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) {
+              continue;
+            } else {
+              int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index +
w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + if (real_count == 0) { + return NNACL_ERR; + } + int16_t tmp_out = round((float)tmp_avg / (float)real_count); + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); + int8_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = real_out; + } // in_channel loop + } // out_plane loop + } // out_batch loop + return NNACL_OK; +} + +int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num) { + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int c16 = channel / C16NUM; + int in_w = compute_args->input_w_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = out_tile_count < thread_num ? out_tile_count : thread_num; + int input_zp = quant_args[0][0].zp_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = quant_args[0][0].scale_ / quant_args[1][0].scale_; + const int8_t out_min = INT8_MIN; + const int8_t out_max = INT8_MAX; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * compute_args->input_h_ * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? 
TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + int out_plane_offset = out_batch_offset + index * channel; + int input_stride = (in_h_index * in_w + in_w_index) * channel; + int kw_s = MSMAX(0, -in_w_index); + int kw_e = MSMIN(win_w, in_w - in_w_index); + int kh_s = MSMAX(0, -in_h_index); + int kh_e = MSMIN(win_h, compute_args->input_h_ - in_h_index); + int real_count = (kw_e - kw_s) * (kh_e - kh_s); + if (real_count == 0) { + return NNACL_ERR; + } + + // 16 channels + for (int j = 0; j < c16; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg[2]; + tmp_avg[0] = vmovq_n_s16(0); + tmp_avg[1] = vmovq_n_s16(0); +#else + int16_t tmp_avg[16]; + int16_t real_out[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_avg[m] = 0; + } +#endif + int in_channel_offset = in_batch_offset + j * C16NUM; + int out_channel_offset = out_plane_offset + j * C16NUM; + + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset); + int8x8_t in_data1 = vget_low_s8(in_ptr); + int8x8_t in_data2 = vget_high_s8(in_ptr); + int16x8_t data1 = vmovl_s8(in_data1); + int16x8_t data2 = vmovl_s8(in_data2); + tmp_avg[0] = vaddq_s16(tmp_avg[0], data1); + tmp_avg[1] = vaddq_s16(tmp_avg[1], data2); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + int16_t tmp_data1[8]; + int16_t tmp_out1[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[0][l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + for (int l = 0; l < C8NUM; l++) { + tmp_data1[l] = tmp_avg[1][l] + 128 * real_count; + tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count; + tmp_out1[l] -= 128; + tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out[2]; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out[0] = vqmovn_s16(vld1q_s16(tmp_out)); + real_out[0] = vmin_s8(real_out[0], output_max); + real_out[0] = vmax_s8(real_out[0], output_min); + vst1_s8(output_ptr + out_channel_offset, real_out[0]); + real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1)); + real_out[1] = vmin_s8(real_out[1], output_max); + real_out[1] = vmax_s8(real_out[1], output_min); + vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]); +#else + for (int l = 0; l < C16NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? 
out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // 8 channels + int channel_16_res = channel - c16 * C16NUM; + int c8 = channel_16_res / C8NUM; + int in_c16_offset = in_batch_offset + c16 * C16NUM; + int out_c16_offset = out_plane_offset + c16 * C16NUM; + for (int j = 0; j < c8; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg = vmovq_n_s16(0); +#else + int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + int16_t real_out[8]; +#endif + int in_channel_offset = in_c16_offset + j * C8NUM; + int out_channel_offset = out_c16_offset + j * C8NUM; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x8_t in_ptr = vld1_s8(input_ptr + in_offset); + int16x8_t data = vmovl_s8(in_ptr); + tmp_avg = vaddq_s16(tmp_avg, data); +#else + for (int k = 0; k < C8NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out = vqmovn_s16(vld1q_s16(tmp_out)); + real_out = vmin_s8(real_out, output_max); + real_out = vmax_s8(real_out, output_min); + vst1_s8(output_ptr + out_channel_offset, real_out); +#else + for (int l = 0; l < C8NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // less than 8 channel + int channel_8_res = channel_16_res - c8 * C8NUM; + int in_c8_offset = in_c16_offset + c8 * C8NUM; + int out_c8_offset = out_c16_offset + c8 * C8NUM; + for (int k = 0; k < channel_8_res; k++) { + int in_channel_offset = in_c8_offset + k; + int out_channel_offset = out_c8_offset + k; + int16_t tmp_avg = 0; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; + tmp_avg += input_ptr[in_offset]; + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128; + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); + int16_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? 
out_max : real_out; + *(output_ptr + out_channel_offset) = (int8_t)real_out; + } // channel_res loop + } // out_plane loop + } // out_batch loop + } + return NNACL_OK; +} + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int output_h = compute_args->output_h_; + int output_batch = compute_args->output_batch_; + int out_plane = output_w * output_h; + // input channel is equal to output channel + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // in_channel loop + } // out_plane loop + } // out_batch loop +} + +void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num) { + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = out_tile_count < thread_num ? 
out_tile_count : thread_num; + int c16 = UP_DIV(channel, 16); + // input channel is equal to output channel + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c16 - 1; j++) { + int in_channel_offset = in_batch_offset + j * 16; + int out_channel_offset = out_plane_offset + j * 16; +#ifdef ENABLE_NEON + int8x16_t tmp_max = vdupq_n_s8(INT8_MIN); +#else + int8_t tmp_max[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_max[m] = INT8_MIN; + } +#endif + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset)); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + for (int l = 0; l < C16NUM; ++l) { + tmp_max[l] = (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } + vst1q_s8(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C16NUM; ++l) { + *(output_ptr + out_channel_offset + l) = + (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } +#endif + } // in_channel loop + + // res channel + int channel_s = (c16 - 1) * 16; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, int task_id, int thread_num) { + int channel = compute_args->input_channel_; + 
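+  // Stripe the output plane across threads in TILE_NUM-pixel tiles; within each tile, channels
+  // are processed in MAX_MAXPOOL_SIZE blocks so the running maxima fit in the on-stack out_array.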
int in_w = compute_args->input_w_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = MSMIN(out_tile_count, thread_num); + int8_t out_array[MAX_MAXPOOL_SIZE]; + + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * compute_args->input_h_ * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = out_plane - cal_start_index; + real_cal_num = MSMIN(real_cal_num, TILE_NUM); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + const int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index); + int ky_e = MSMIN(compute_args->window_h_, compute_args->input_h_ - in_h_index); + const int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index); + int kx_e = MSMIN(compute_args->window_w_, in_w - in_w_index); + int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset; + int out_plane_offset = out_batch_offset + index * channel; + + int c = 0; + for (; c < channel; c += MAX_MAXPOOL_SIZE) { + int real_channel = channel - c; + real_channel = MSMIN(real_channel, MAX_MAXPOOL_SIZE); + memset(out_array, INT8_MIN, real_channel); + int8_t *out_data = output_ptr + out_plane_offset + c; + for (int h = ky_s; h < ky_e; ++h) { + int in_h_offset = input_stride + h * in_w * channel + c; + for (int w = kx_s; w < kx_e; ++w) { + const int8_t *in_data = input_ptr + in_h_offset + w * channel; + int j = 0; +#ifdef ENABLE_NEON + const int8_t *tmp_in_data = in_data; + int c16 = real_channel / 16 * 16; + int c8 = real_channel / 8 * 8; + for (; j < c16; j += 16) { + int8x16_t ori_in = vld1q_s8(tmp_in_data); + int8x16_t out_array16 = vld1q_s8(out_array + j); + tmp_in_data += 16; + out_array16 = vmaxq_s8(ori_in, out_array16); + vst1q_s8(out_array + j, out_array16); + } // 16 channel loop + + for (; j < c8; j += 8) { + int8x8_t ori_in = vld1_s8(tmp_in_data); + int8x8_t out_array8 = vld1_s8(out_array + j); + tmp_in_data += 8; + out_array8 = vmax_s8(ori_in, out_array8); + vst1_s8(out_array + j, out_array8); + } // 8 channel loop +#endif + for (; j < real_channel; ++j) { + out_array[j] = out_array[j] > in_data[j] ? 
out_array[j] : in_data[j];
+            }
+          }  // kw loop
+        }    // kh loop
+
+        int j = 0;
+#ifdef ENABLE_NEON
+        int c16 = real_channel / 16 * 16;
+        int c8 = real_channel / 8 * 8;
+        int8_t *tmp_out_data = out_data;
+        for (; j < c16; j += 16) {
+          vst1q_s8(tmp_out_data, vld1q_s8(out_array + j));
+          tmp_out_data += 16;
+        }  // 16 channel loop
+
+        for (; j < c8; j += 8) {
+          vst1_s8(tmp_out_data, vld1_s8(out_array + j));
+          tmp_out_data += 8;
+        }  // 8 channel loop
+#endif
+        for (; j < real_channel; ++j) {
+          out_data[j] = out_array[j];
+        }
+      }  // 256 channel loop
+    }  // out_plane loop
+  }  // out_batch loop
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.h
new file mode 100644
index 00000000..3b1be326
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/pooling_int8.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_POOLING_H_
+#define NNACL_INT8_POOLING_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/kernel/pooling.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+#define MAX_MAXPOOL_SIZE 256
+
+int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param,
+                   PoolingComputeParam *compute_args, QuantArg **quant_args);
+
+int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param,
+                      PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num);
+
+void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param,
+                    PoolingComputeParam *compute_args, QuantArg **quant_args);
+
+void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param,
+                             PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num);
+
+void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param,
+                       PoolingComputeParam *compute_args, int task_id, int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_POOLING_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.c
new file mode 100644
index 00000000..3fc6a60c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.c
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
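The int8 max-pooling kernels above requantize each pooled maximum with real_multiplier = input_scale / output_scale. A minimal scalar sketch of that step, with illustrative quant parameters (the clamp is added here for safety; the kernels rely on saturating NEON narrows or a direct cast):

#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* out = round((max_in - in_zp) * in_scale / out_scale) + out_zp */
static int8_t RequantMaxInt8(int8_t max_in, float in_scale, int in_zp, float out_scale, int out_zp) {
  double real_multiplier = in_scale / out_scale;
  long v = lround((max_in - in_zp) * real_multiplier) + out_zp;
  if (v > INT8_MAX) v = INT8_MAX;  /* safety clamp, not in the kernel's scalar path */
  if (v < INT8_MIN) v = INT8_MIN;
  return (int8_t)v;
}

int main(void) {
  /* in_scale 0.05 / zp 3, out_scale 0.1 / zp -2: a pooled max of 90 maps to 42 */
  printf("%d\n", RequantMaxInt8(90, 0.05f, 3, 0.1f, -2));
  return 0;
}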
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/power_int8.h" + +int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, const PowQuantArg *args, + bool broadcast, int scale, int shift) { + double input_scale = args->in_args_.scale_; + int input_zp = args->in_args_.zp_; + double output_scale = args->out_args_.scale_; + int output_zp = args->out_args_.zp_; + int act_min = args->output_activation_min_; + int act_max = args->output_activation_max_; + double exp_scale = args->exp_args_.scale_; + int exp_zp = args->exp_args_.zp_; + + if (broadcast) { + float exp_val = exp_scale * (exp_ptr[0] - exp_zp); + for (int i = 0; i < count; ++i) { + float input_val = input_scale * (input[i] - input_zp); + float output_val = pow(scale * input_val + shift, exp_val); + int32_t output_scaled = round(output_val / output_scale) + output_zp; + output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max)); + } + } else { + for (int i = 0; i < count; ++i) { + float input_val = input_scale * (input[i] - input_zp); + float exp_val = exp_scale * (exp_ptr[i] - exp_zp); + float output_val = pow(scale * input_val + shift, exp_val); + int32_t output_scaled = round(output_val / output_scale) + output_zp; + output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max)); + } + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.h new file mode 100644 index 00000000..7bf8b809 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/power_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
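PowerInt8 above dequantizes both operands, applies pow(scale * x + shift, exp) in float, then requantizes with the activation clamp. A worked float-domain sketch of that pipeline, using illustrative quant parameters rather than the kernel's PowQuantArg:

#include <math.h>
#include <stdio.h>

int main(void) {
  float in_scale = 0.1f;  int in_zp = 0;   /* assumed input quant params */
  float out_scale = 0.5f; int out_zp = 0;  /* assumed output quant params */
  int scale = 1, shift = 0;                /* affine pre-transform, as in the kernel */
  int q_in = 30;                           /* dequantizes to 3.0 */
  float exp_val = 2.0f;                    /* exponent after dequantization */
  float x = in_scale * (q_in - in_zp);
  float y = powf(scale * x + shift, exp_val);       /* 9.0 */
  int q_out = (int)roundf(y / out_scale) + out_zp;  /* 18 */
  printf("q_out = %d\n", q_out);
  return 0;
}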
+ */
+
+#ifndef NNACL_INT8_POWER_INT8_H_
+#define NNACL_INT8_POWER_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/pow_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, const PowQuantArg *args,
+              bool broadcast, int scale, int shift);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_POWER_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.c
new file mode 100644
index 00000000..edd05f50
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.c
@@ -0,0 +1,437 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <math.h>
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/errorcode.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+#ifdef ENABLE_ARM64
+inline void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size) {
+  asm volatile(
+    "mov w8, %w[size]\n"
+    "cmp w8, #0\n"
+    "beq 2f\n"
+
+    "dup v20.4s, %w[zp32]\n"
+    "dup v21.4s, %w[scale]\n"
+
+    "cmp w8, #16\n"
+    "blt 1f\n"
+
+    "0:\n"
+    "subs w8, w8, #16\n"
+    "ld1 {v7.16b}, [%[quant_values]], #16\n"
+
+    "sxtl v8.8h, v7.8b\n"
+    "sxtl2 v9.8h, v7.16b\n"
+
+    "sxtl v0.4s, v8.4h\n"
+    "sxtl2 v1.4s, v8.8h\n"
+    "sxtl v2.4s, v9.4h\n"
+    "sxtl2 v3.4s, v9.8h\n"
+    "sub v0.4s, v0.4s, v20.4s\n"
+    "sub v1.4s, v1.4s, v20.4s\n"
+    "sub v2.4s, v2.4s, v20.4s\n"
+    "sub v3.4s, v3.4s, v20.4s\n"
+    "scvtf v4.4s, v0.4s\n"
+    "scvtf v5.4s, v1.4s\n"
+    "scvtf v6.4s, v2.4s\n"
+    "scvtf v7.4s, v3.4s\n"
+
+    "fmul v0.4s, v4.4s, v21.4s\n"
+    "fmul v1.4s, v5.4s, v21.4s\n"
+    "fmul v2.4s, v6.4s, v21.4s\n"
+    "fmul v3.4s, v7.4s, v21.4s\n"
+
+    "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[dst]], #64\n"
+    "beq 2f\n"
+    "cmp w8, #16\n"
+    "bge 0b\n"
+
+    "1:\n"
+    "ldrsb w9, [%[quant_values]], #1\n"
+
+    "subs w8, w8, #1\n"
+    "sub w9, w9, %w[zp32]\n"
+    "scvtf s9, w9\n"
+
+    "fmul s9, s9, s21\n"
+    "str s9, [%[dst]], #4\n"
+    "bne 1b\n"
+
+    "2:\n"
+
+    :
+    : [quant_values] "r"(quant_values), [dst] "r"(dst), [scale] "r"(scale), [zp32] "r"(zp), [size] "r"(size)
+    : "w8", "w9", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v20", "v21");
+}
+#endif
+
+int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+#ifdef ENABLE_ARM64
+  Int8ToFp32_arm64(quant_values, real_values, scale, zp, size);
+#else
+  for (int i = 0; i < size; i++) {
+    real_values[i] = (quant_values[i] - zp) * scale;
+  }
+#endif
+  return NNACL_OK;
+}
+
+#ifdef ENABLE_ARM64
+inline void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
+                             int32_t min_value, int32_t max_value) {
+  float ivs = 1.0f / scale;
+
+  asm volatile(
+    "mov w8, 
%w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "dup v12.4s, %w[ivs]\n" + "dup v13.4s, %w[min_value]\n" + "dup v14.4s, %w[max_value]\n" + "cmp w8, #16\n" + "blt 1f\n" + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n" + "dup v8.4s, %w[zp]\n" + "dup v9.4s, %w[zp]\n" + "dup v10.4s, %w[zp]\n" + "dup v11.4s, %w[zp]\n" + "scvtf v4.4s, v8.4s\n" + "scvtf v5.4s, v9.4s\n" + "scvtf v6.4s, v10.4s\n" + "scvtf v7.4s, v11.4s\n" + "fmla v4.4s, v0.4s, v12.4s\n" + "fmla v5.4s, v1.4s, v12.4s\n" + "fmla v6.4s, v2.4s, v12.4s\n" + "fmla v7.4s, v3.4s, v12.4s\n" + + "fcvtas v0.4s, v4.4s\n" + "fcvtas v1.4s, v5.4s\n" + "fcvtas v2.4s, v6.4s\n" + "fcvtas v3.4s, v7.4s\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smax v1.4s, v1.4s, v13.4s\n" + "smax v2.4s, v2.4s, v13.4s\n" + "smax v3.4s, v3.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "smin v1.4s, v1.4s, v14.4s\n" + "smin v2.4s, v2.4s, v14.4s\n" + "smin v3.4s, v3.4s, v14.4s\n" + + "sqxtn v4.4h, v0.4s\n" + "sqxtn2 v4.8h, v1.4s\n" + "sqxtn v5.4h, v2.4s\n" + "sqxtn2 v5.8h, v3.4s\n" + "sqxtn v6.8b, v4.8h\n" + "sqxtn2 v6.16b, v5.8h\n" + "st1 {v6.16b}, [%[quant_values]], #16\n" + + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "scvtf s0, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr s4, [%[real_values]], #4\n" + "fmul s4, s4, s12\n" + "fadd s0, s0, s4\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "sqxtn v1.4h, v0.4s\n" + "sqxtn v0.8b, v1.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + + "bne 1b\n" + + "2:\n" + : + : [quant_values] "r"(quant_values), [real_values] "r"(real_values), [scale] "r"(scale), [zp] "r"(zp), + [size] "r"(size), [ivs] "r"(ivs), [min_value] "r"(min_value), [max_value] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14"); +} +#endif + +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp32ToInt8_arm64(real_values, quant_values, scale, zp, size, min_value, max_value); +#else + const float inverse_scale = 1.0f / scale; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? 
temp : min_value; + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +#ifdef ENABLE_ARM64 +inline void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps, + int size, int row_length, int32_t min_value, int32_t max_value) { + volatile float ivs[size]; + for (int i = 0; i < size; i++) { + volatile int channel_index = i / row_length; + ivs[i] = 1.0f / scales[channel_index]; + } + volatile int32_t zp = zps[0]; + + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "mov x4, %[ivs]\n" // reload ivs + "dup v13.4s, %w[min_value]\n" + "dup v14.4s, %w[max_value]\n" + "cmp w8, #16\n" + "blt 1f\n" + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n" + "dup v8.4s, %w[zp]\n" + "dup v9.4s, %w[zp]\n" + "dup v10.4s, %w[zp]\n" + "dup v11.4s, %w[zp]\n" + "scvtf v4.4s, v8.4s\n" + "scvtf v5.4s, v9.4s\n" + "scvtf v6.4s, v10.4s\n" + "scvtf v7.4s, v11.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v4.4s, v0.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v5.4s, v1.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v6.4s, v2.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v7.4s, v3.4s, v12.4s\n" + + "fcvtas v0.4s, v4.4s\n" + "fcvtas v1.4s, v5.4s\n" + "fcvtas v2.4s, v6.4s\n" + "fcvtas v3.4s, v7.4s\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smax v1.4s, v1.4s, v13.4s\n" + "smax v2.4s, v2.4s, v13.4s\n" + "smax v3.4s, v3.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "smin v1.4s, v1.4s, v14.4s\n" + "smin v2.4s, v2.4s, v14.4s\n" + "smin v3.4s, v3.4s, v14.4s\n" + + "sqxtn v4.4h, v0.4s\n" + "sqxtn2 v4.8h, v1.4s\n" + "sqxtn v5.4h, v2.4s\n" + "sqxtn2 v5.8h, v3.4s\n" + "sqxtn v6.8b, v4.8h\n" + "sqxtn2 v6.16b, v5.8h\n" + "st1 {v6.16b}, [%[quant_values]], #16\n" + + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "scvtf s0, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr s4, [%[real_values]], #4\n" + "fmul s4, s4, s12\n" + "fadd s0, s0, s4\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "sqxtn v1.4h, v0.4s\n" + "sqxtn v0.8b, v1.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + + "bne 1b\n" + + "2:\n" + : + : [quant_values] "r"(quant_values), [real_values] "r"(real_values), [scales] "r"(scales), [zp] "r"(zp), + [size] "r"(size), [row_length] "r"(row_length), [ivs] "r"(ivs), [min_value] "r"(min_value), + [max_value] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "x4"); +} +#endif + +int DoChannelRowFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL || row_length == 0) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp32ToInt8Perchannel_arm64(real_values, quant_values, scale, zp, size, row_length, min_value, max_value); +#else + for (int i = 0; i < size; ++i) { + int channel_index = i / row_length; + const float inverse_scale = 1.0f / scale[channel_index]; + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp[channel_index]); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? 
temp : min_value; + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +int DoChannelColFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL || row_length == 0) { + return NNACL_PARAM_INVALID; + } + int row_total = size / row_length; + for (int r = 0; r < row_total; r++) { + const float *real_current = real_values + r * row_length; + int8_t *quant_current = quant_values + r * row_length; + for (int c = 0; c < row_length; c++) { + const float inverse_scale = 1.0f / scale[c]; + if (real_current[c] == INFINITY) { + quant_current[c] = max_value; + } else if (real_current[c] == -INFINITY) { + quant_current[c] = min_value; + } else { + int temp = round(real_current[c] * inverse_scale + zp[c]); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? temp : min_value; + quant_current[c] = (int8_t)temp; + } + } + } + return NNACL_OK; +} + +int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, + int size, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + zp += 128; + const float inverse_scale = 1.0f / scale; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp); + temp -= 128; + temp = temp < 127 ? temp : 127; + temp = temp > -128 ? temp : -128; + quant_values[i] = (int8_t)temp; + } + } + return NNACL_OK; +} + +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + real_values[i] = (float)((int)quant_values[i] - zp) * scale; + } + return NNACL_OK; +} + +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + if (isinf(real_values[i])) { + quant_values[i] = 255; + } else { + float temp = (float)round(real_values[i] * 1.0 / scale + zp); + if (temp > 255) { + quant_values[i] = 255; + } else if (temp < 0) { + quant_values[i] = 0; + } else { + quant_values[i] = (uint8_t)temp; + } + } + } + return NNACL_OK; +} + +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + int temp = quant_values[i] + 128; + if (temp > 255) { + real_values[i] = (uint8_t)255; + } else if (temp < 0) { + real_values[i] = 0; + } else { + real_values[i] = (uint8_t)temp; + } + } + return NNACL_OK; +} + +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + int temp = (int)real_values[i] - 128; + if (temp > 127) { + quant_values[i] = 127; + } else if (temp < -128) { + quant_values[i] = -128; + } else { + quant_values[i] = (int8_t)temp; + } + } + return NNACL_OK; +} diff --git 
a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.h new file mode 100644 index 00000000..6de46c60 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quant_dtype_cast_int8.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_QUANTDTYPECAST_H_ +#define NNACL_INT8_QUANTDTYPECAST_H_ + +#include "nnacl_c/op_base.h" + +typedef struct QuantDTypeCastParameter { + OpParameter op_parameter_; + int32_t srcT; + int32_t dstT; + int32_t axis; +} QuantDTypeCastParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value); +int DoChannelRowFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value); +int DoChannelColFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value); +int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, + int size, int32_t min_value, int32_t max_value); +#ifdef ENABLE_ARM64 +void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value); +void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size); +void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps, int size, + int row_length, int32_t min_value, int32_t max_value); +#endif +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size); +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_QUANTDTYPECAST_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.c new file mode 100644 index 00000000..e9eadb23 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
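A usage sketch for the scalar conversion paths declared above (assuming the nnacl_c headers are on the include path): quantize with scale 0.02 and zero point 5, then dequantize.

#include <stdio.h>
#include "nnacl_c/int8/quant_dtype_cast_int8.h"

int main(void) {
  float in[4] = {-1.0f, 0.0f, 0.5f, 1.27f};
  int8_t q[4];
  float out[4];
  DoQuantizeFp32ToInt8(in, q, 0.02f, 5, 4, -128, 127);  /* q[i] = round(in[i] / 0.02) + 5 */
  DoDequantizeInt8ToFp32(q, out, 0.02f, 5, 4);          /* out[i] = (q[i] - 5) * 0.02 */
  for (int i = 0; i < 4; ++i) {
    printf("%.4f -> %d -> %.4f\n", in[i], q[i], out[i]);
  }
  return 0;
}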
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/quantize.h"
+#include <math.h>
+#include <limits.h>
+
+const uint64_t dSignMask = 1ull << 63;
+const uint64_t dExponentMask = 0x7ffull << 52;
+const uint64_t dFractionMask = (1ull << 52) - 1;
+const int dExponentBias = 1022;
+const int dMantissaBits = 52;
+const int dInfiniteExponent = 0x7ff;
+const double dNormalizer = 0x1p54;
+const int dNormalizerBias = 54;
+const int iMantissaBits = 31;
+
+void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int32_t *right_shift) {
+  if (quantized_multiplier == NULL || right_shift == NULL) {
+    return;
+  }
+  int shift = 0;
+  QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift);
+  *right_shift = -shift;
+}
+
+void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier,
+                                               int32_t *left_shift, int32_t *right_shift) {
+  int shift = 0;
+  QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift);
+  shift = -shift;
+  if (shift < 0) {
+    *left_shift = 0;
+    *right_shift = shift;
+  } else {
+    *left_shift = shift;
+    *right_shift = 0;
+  }
+}
+
+void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier,
+                                               int32_t *left_shift, int32_t *right_shift) {
+  int shift = 0;
+  const uint32_t scale_bits = (uint32_t)(double_multiplier);
+  /* multiplier is in[0x40000000, 0x7FFFFF80] range */
+  *quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+  if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) {
+    return;
+  }
+  /* shift is in [0, 31] range */
+  shift = 127 + 31 - 32 - ((uint32_t)(double_multiplier) >> 23);
+  shift = -shift;
+  if (shift < 0) {
+    *left_shift = 0;
+    *right_shift = shift;
+  } else {
+    *left_shift = shift;
+    *right_shift = 0;
+  }
+}
+
+uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
+
+int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
+
+void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int32_t *mini,
+                                       int32_t *maxi) {
+  int32_t min = INT8_MIN;
+  int32_t max = INT8_MAX;
+  int32_t quantized_zero = QuantizeToInt8(0, scale, zp);
+  int32_t quantized_six = QuantizeToInt8(6, scale, zp);
+  if (is_relu) {
+    min = min > quantized_zero ? min : quantized_zero;
+  } else if (is_relu6) {
+    min = min > quantized_zero ? min : quantized_zero;
+    max = max < quantized_six ? max : quantized_six;
+  } else {
+    // do nothing
+  }
+  *mini = min;
+  *maxi = max;
+}
+
+// quantize from float to int8
+void Quantize(const float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
+  for (int i = 0; i < length; ++i) {
+    int q = (int)round(input_data[i] / scale + zero_point);
+    q = q > SCHAR_MAX ? SCHAR_MAX : q;
+    q = q < SCHAR_MIN ? 
SCHAR_MIN : q; + output_data[i] = (int8_t)q; + } +} + +// dequantize from int8 to float +void Dequantize(const int8_t *input_data, int length, float scale, int zero_point, float *output_data) { + for (int i = 0; i < length; ++i) { + output_data[i] = scale * (input_data[i] - zero_point); + } +} + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int32_t *shift) { + if (quantized_multiplier == NULL || shift == NULL) { + return; + } + // we split a floating number into two parts: exponent and fraction + // since fraction is stored as int32, only 31 bits of mantissa is remained + union { + double d; + uint64_t ul; + } dul; + dul.d = double_multiplier; + if (!(dul.ul & (~dSignMask))) { + // multiplier is 0 + *quantized_multiplier = 0; + *shift = 0; + return; + } + int exponent = (int)((dul.ul & dExponentMask) >> dMantissaBits); + if (exponent == dInfiniteExponent) { + // multiplier is inf or NaN + *shift = 0; + if (!(dul.ul & dFractionMask)) { + // inf + *quantized_multiplier = (dul.ul & dSignMask) ? INT_MIN : INT_MAX; + } else { + // NaN + *quantized_multiplier = 0; + } + return; + } + if (exponent == 0) { + // multiplier is a subnormal number + dul.d *= dNormalizer; + exponent = (int)((dul.ul & dExponentMask) >> dMantissaBits); + *shift = exponent - dExponentBias - dNormalizerBias; + } else { + *shift = exponent - dExponentBias; + } + uint64_t fraction = dul.ul & dFractionMask; + fraction += (1ull << dMantissaBits); + uint64_t rounded = ((fraction >> (dMantissaBits - iMantissaBits)) + 1ull) >> 1; + // we get 31 rounded bits now + if (rounded == (1ull << iMantissaBits)) { + // rounding may cause a carry + rounded >>= 1; + ++*shift; + } + *quantized_multiplier = (dul.ul & dSignMask) ? (-(int32_t)(rounded)) : (int32_t)(rounded); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.h new file mode 100644 index 00000000..380713c5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/quantize.h @@ -0,0 +1,222 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
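QuantizeMultiplier above splits a double M into a Q31 fixed-point multiplier m and a shift s such that M is approximately m * 2^(s - 31). A small verification sketch (assuming quantize.h is on the include path):

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include "nnacl_c/int8/quantize.h"

int main(void) {
  double M = 0.1234;
  int32_t m = 0;
  int32_t s = 0;
  QuantizeMultiplier(M, &m, &s);
  /* m carries 31 rounded mantissa bits, so M is reconstructed as m * 2^(s - 31) */
  double back = (double)m * ldexp(1.0, s - 31);
  printf("m = %d, s = %d, reconstructed = %.10f (target %.10f)\n", m, s, back, M);
  return 0;
}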
+ */ + +#ifndef NNACL_QUANTIZATION_QUANTIZE_H_ +#define NNACL_QUANTIZATION_QUANTIZE_H_ + +#include +#include "nnacl_c/op_base.h" + +#define INPUT_PER_CHANNEL 0b001 +#define FILTER_PER_CHANNEL 0b010 +#define OUTPUT_PER_CHANNEL 0b100 + +typedef struct ConvQuantArg { + RoundingMode round_mode_; + CalFixedMultiplierMode quant_multiplier_mode_; + QuantArg *input_quant_args_; + QuantArg *filter_quant_args_; + QuantArg *output_quant_args_; + double *real_multiplier_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; + int32_t *out_act_min_; + int32_t *out_act_max_; + size_t input_arg_num_; + size_t filter_arg_num_; + size_t output_arg_num_; + uint8_t per_channel_; +} ConvQuantArg; + +typedef struct ConcatQuantArg { + QuantArg *in_args_; + QuantArg out_args_; + int8_t output_activation_min_; + int8_t output_activation_max_; +} ConcatQuantArg; + +typedef struct PreluQuantArg { + int32_t *input_sizes_; + int output_size_; + int32_t **input_shapes_; + int32_t *output_shape_; + size_t input_num_; + size_t output_dim_; + float alpha_; + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + QuantArg *in_quant_args_; + QuantArg out_quant_args_; +} PreluQuantArg; + +typedef struct CropQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} CropQuantArg; + +typedef struct ArithSelfQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int shift_left_; + int shift_right_; +} ArithSelfQuantArg; + +typedef struct GatherQuantArg { + double alpha_; + int zp_in_; + int zp_out_; +} GatherQuantArg; + +typedef struct DynamicGatherQuantArg { + float *scale_in_; + int32_t *zp_in_; +} DynamicGatherQuantArg; + +typedef struct SoftmaxQuantArg { + QuantArg in_quant_args_; + QuantArg out_quant_arg_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int shift_left_; + int shift_right_; +} SoftmaxQuantArg; + +typedef struct SubQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int input0_multiplier_; + int input1_multiplier_; + int output_multiplier_; + int input0_shift_; + int input1_shift_; + int output_shift_; + int left_shift_result0_; + int left_shift_result1_; + int right_shift0_; + int right_shift1_; + int left_shift_out_; + int right_shift_out_; +} SubQuantArg; + +typedef struct ArithmeticQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; +} ArithmeticQuantArg; + +typedef struct DivQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int output_shift_; +} DivQuantArg; + +typedef struct ReduceQuantArg { + double in_scale_; + int32_t in_zp_; + double out_scale_; + int32_t out_zp_; + int32_t in_out_multiplier_; + int in_out_left_shift_; + int in_out_right_shift_; + int32_t mean_multiplier_; + int mean_left_shift_; + int mean_right_shift_; + int32_t prod_multiplier_; + int prod_left_shift_; + int prod_right_shift_; + int32_t sum_square_multiplier_; + int sum_square_left_shift_; + int sum_square_right_shift_; +} ReduceQuantArg; + +typedef struct LeakyReluQuantArg { + QuantArg in_args_; + QuantArg out_args_; + float slope_; + int input_dim_; + int element_num; + int thread_num_; +} LeakyReluQuantArg; + +typedef struct ResizeQuantArg { + 
int32_t ratio_x_; + int32_t ratio_y_; + int32_t *x_axis_index_; + int32_t *x_axis_lower_; + int32_t *x_axis_upper_; + int32_t *y_axis_index_; + int32_t *y_axis_lower_; + int32_t *y_axis_upper_; +} ResizeQuantArg; + +typedef struct ResizeFloatScaleQuantArg { + float ratio_x_; + float ratio_y_; + float *x_axis_index_; + int32_t *x_axis_lower_; + int32_t *x_axis_upper_; + float *y_axis_index_; + int32_t *y_axis_lower_; + int32_t *y_axis_upper_; +} ResizeFloatScaleQuantArg; + +#ifdef __cplusplus +extern "C" { +#endif + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int32_t *shift); + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int32_t *right_shift); + +void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift); + +void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift); + +uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp); + +int32_t QuantizeToInt8(float real_value, float scale, int32_t zp); + +void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int32_t *mini, + int32_t *maxi); +// quantize from float to int8 +void Quantize(const float *input_data, int length, float scale, int zero_point, int8_t *output_data); + +// dequantize from int8 to float +void Dequantize(const int8_t *input_data, int length, float scale, int zero_point, float *output_data); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_QUANTIZATION_QUANTIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.c new file mode 100644 index 00000000..be1aca7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.c @@ -0,0 +1,597 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
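A usage sketch for CalculateActivationRangeQuantized declared above: for a relu6 output quantized with scale 0.05 and zero point -20, the float interval [0, 6] maps to quantized [-20, 100] before the int8 clamp.

#include <stdint.h>
#include <stdio.h>
#include "nnacl_c/int8/quantize.h"

int main(void) {
  int32_t mini = 0;
  int32_t maxi = 0;
  /* is_relu = false, is_relu6 = true, zp = -20, scale = 0.05 */
  CalculateActivationRangeQuantized(false, true, -20, 0.05f, &mini, &maxi);
  printf("min = %d, max = %d\n", mini, maxi);  /* expect -20 and 100 */
  return 0;
}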
+ */ + +#include +#include "nnacl_c/int8/reduce_int8.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/int8/fixed_point.h" +#include "nnacl_c/common_func.h" + +int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanHW(int n, int plane, int count, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg, + int32_t bias) { + int stride = plane * UP_ROUND(c, C4NUM); + for (int batch = 0; batch < n; ++batch) { + int8_t *in_ptr = in_data + batch * stride; + int8_t *out_ptr = out_data + batch * c; + for (int i = 0; i < count; ++i) { + int32_t sum_array = 0; + int j = 0; +#ifdef ENABLE_ARM64 + for (; j < plane; j += 16) { + int8x16_t in_data_vec = vld1q_s8(in_ptr); + sum_array += vaddlvq_s8(in_data_vec); + in_ptr += 16; + } + for (; j < plane; j += 8) { + int8x8_t in_data_vec = vld1_s8(in_ptr); + sum_array += vaddlv_s8(in_data_vec); + in_ptr += 8; + } + for (; j < plane; j += 4) { + int32x4_t in_data_vec; + in_data_vec[0] = in_ptr[0]; + in_data_vec[1] = in_ptr[1]; + in_data_vec[2] = in_ptr[2]; + in_data_vec[3] = in_ptr[3]; + sum_array += vaddvq_s32(in_data_vec); + in_ptr += 4; + } +#elif ENABLE_ARM32 + int32x4_t accum = vmovq_n_s32(0); + for (; j < plane; j += 16) { + int32x4_t in_data_vec1; + int32x4_t in_data_vec2; + int32x4_t in_data_vec3; + int32x4_t in_data_vec4; + in_data_vec1[0] = in_ptr[0]; + in_data_vec1[1] = in_ptr[1]; + in_data_vec1[2] = in_ptr[2]; + in_data_vec1[3] = in_ptr[3]; + in_data_vec2[0] = in_ptr[4]; + in_data_vec2[1] = in_ptr[5]; + in_data_vec2[2] = in_ptr[6]; + in_data_vec2[3] = in_ptr[7]; + in_data_vec3[0] = in_ptr[8]; + in_data_vec3[1] = in_ptr[9]; + in_data_vec3[2] = in_ptr[10]; + in_data_vec3[3] = in_ptr[11]; + in_data_vec4[0] = in_ptr[12]; + in_data_vec4[1] = in_ptr[13]; + in_data_vec4[2] = in_ptr[14]; + in_data_vec4[3] = in_ptr[15]; + accum = vaddq_s32(accum, in_data_vec1); + accum = vaddq_s32(accum, in_data_vec2); + accum = vaddq_s32(accum, in_data_vec3); + accum = vaddq_s32(accum, in_data_vec4); + in_ptr += 16; + } + for (; j < plane; j += 8) { + int32x4_t in_data_vec1; + int32x4_t in_data_vec2; + in_data_vec1[0] = in_ptr[0]; + in_data_vec1[1] = in_ptr[1]; + in_data_vec1[2] = in_ptr[2]; + in_data_vec1[3] = in_ptr[3]; + in_data_vec2[0] = in_ptr[4]; + in_data_vec2[1] = in_ptr[5]; + in_data_vec2[2] = in_ptr[6]; + in_data_vec2[3] = in_ptr[7]; + accum = vaddq_s32(accum, in_data_vec1); + accum = vaddq_s32(accum, in_data_vec2); + in_ptr += 8; + } + for (; j < plane; j += 4) { + int32x4_t in_data_vec; + in_data_vec[0] = in_ptr[0]; + in_data_vec[1] = in_ptr[1]; + in_data_vec[2] = in_ptr[2]; + in_data_vec[3] = in_ptr[3]; + accum = vaddq_s32(accum, in_data_vec); + in_ptr += 4; + } + sum_array += accum[0]; + sum_array += accum[1]; + 
sum_array += accum[2]; + sum_array += accum[3]; +#endif + for (; j < plane; j++) { + sum_array += in_ptr[0]; + in_ptr++; + } + int32_t mean = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum_array * (1 << (unsigned int)quant_arg.left_shift_), + quant_arg.multiplier_), + quant_arg.right_shift_); + mean += bias; + *out_ptr++ = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN); + } + } + return NNACL_OK; +} + +int ReduceMeanHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +// Get x such that (x-zp_in) * scale_in = mean +// Assuming reduce n axes, this works for first n-1 reduce. One call for one reduce. +int ReduceMeanInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t mean = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_), + quant->mean_right_shift_); + if (isAddOverflow(mean, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = mean + quant->in_zp_; + } + } + return NNACL_OK; +} + +// suppose reduce n axes, this works for last reduce axis. 
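The first-stage reduces above keep intermediates on the input scale: requiring (x' - zp_in) * s_in to equal the mean of (x_i - zp_in) * s_in gives, in the kernel's notation,

x' \;=\; zp_{in} \;+\; \frac{1}{N}\sum_{i=1}^{N}\left(x_i - zp_{in}\right),

so stages compose without touching the quant parameters, and only the last reduced axis converts to the output scale, as the following comment and function describe.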
+// get y such that (y-zp_out) * scale_out = mean(x-zp_in)*scale_in +int ReduceMeanLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t mean = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_), + quant->mean_right_shift_); + // trans to output scale + int32_t mean_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(mean * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(mean_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + mean = mean_scaled + quant->out_zp_; + + *inner_dst = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN); + } + } + return NNACL_OK; +} + +// Get x such that (x-zp_in) * scale_in = sum(item-zp_in)*scale_in +// Assuming reduce n axes, this works for first n-1 reduce. One call for one reduce. +int ReduceSumInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + + if (isAddOverflow(quant->in_zp_, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = sum + quant->in_zp_; + } + } + return NNACL_OK; +} + +// suppose reduce n axes, this works for last reduce axis. 
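ReduceSumInt8 above likewise leaves accumulated sums on the input scale; the *LastAxis variants then rescale by s_in / s_out. A double-precision reference for what the fixed-point chain (SaturatingRoundingDoublingHighMul plus RoundingDivideByPOT) approximates, with illustrative values:

#include <math.h>
#include <stdio.h>

/* y = round(acc * s_in / s_out) + zp_out, where acc already accumulates (x - zp_in) */
static int LastAxisRequant(int acc, double s_in, double s_out, int zp_out) {
  return (int)lround(acc * (s_in / s_out)) + zp_out;
}

int main(void) {
  /* acc = 37 on input scale 0.1; output scale 0.25, zp_out = 4: 37 * 0.4 = 14.8 -> 15; 15 + 4 = 19 */
  printf("%d\n", LastAxisRequant(37, 0.1, 0.25, 4));
  return 0;
}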
+// get y such that (y-zp_out) * scale_out = sum(item-zp_in)*scale_in +int ReduceSumLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t sum_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(sum_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum = sum_scaled + quant->out_zp_; + if (sum > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (sum < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)sum; + } + } + } + return NNACL_OK; +} + +int ReduceMaxLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MIN; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + int32_t tmp_scaled = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul((tmp - quant->in_zp_) * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(tmp_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + tmp = tmp_scaled + quant->out_zp_; + if (tmp > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (tmp < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int ReduceMaxInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MIN; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? 
tmp : inner_src[i * inner_size]; + } + + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceMinLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + const int base_offset = 20; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + int32_t tmp_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + (tmp - quant->in_zp_) * (1 << ((unsigned int)quant->in_out_left_shift_ + base_offset)), + quant->in_out_multiplier_), + quant->in_out_right_shift_ + base_offset); + if (isAddOverflow(tmp_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + tmp = tmp_scaled + quant->out_zp_; + if (tmp > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (tmp < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int ReduceMinInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? 
tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceProdLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t prod = 1; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isMulOverflow(prod, tmp)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + prod *= tmp; + } + prod = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_), + quant->prod_right_shift_); + int32_t prod_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(prod_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + prod = prod_scaled + quant->out_zp_; + if (prod > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (prod < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)prod; + } + } + } + return NNACL_OK; +} + +int ReduceProdInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t prod = 1; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isMulOverflow(prod, tmp)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + prod *= tmp; + } + prod = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_), + quant->prod_right_shift_); + if (isAddOverflow(prod, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = prod + quant->in_zp_; + } + } + return NNACL_OK; +} + +int ReduceSumSquareLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp; + if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * 
inner_size] - quant->in_zp_); + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t sum_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_), + quant->sum_square_multiplier_), + quant->sum_square_right_shift_); + if (isAddOverflow(sum_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum = sum_scaled + quant->out_zp_; + + if (sum > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (sum < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)sum; + } + } + } + return NNACL_OK; +} + +int ReduceSumSquareInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp; + if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * inner_size] - quant->in_zp_); + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + sum = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_), + quant->sum_square_multiplier_), + quant->sum_square_right_shift_); + if (isAddOverflow(sum, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = sum + quant->in_zp_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.h new file mode 100644 index 00000000..f8302ae8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reduce_int8.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
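The reduce kernels guard every accumulation with isAddOverflow / isMulOverflow from nnacl_c/common_func.h. A portable sketch of equivalent checks, widening to 64-bit (the NNACL helpers may differ in detail):

#include <limits.h>
#include <stdbool.h>
#include <stdio.h>

static bool AddOverflowsInt32(int a, int b) {
  long long s = (long long)a + (long long)b;
  return s > INT_MAX || s < INT_MIN;
}

static bool MulOverflowsInt32(int a, int b) {
  long long p = (long long)a * (long long)b;
  return p > INT_MAX || p < INT_MIN;
}

int main(void) {
  printf("%d %d\n", AddOverflowsInt32(INT_MAX, 1), MulOverflowsInt32(1 << 20, 1 << 12));  /* 1 1 */
  return 0;
}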
+ */ + +#ifndef NNACL_INT8_REDUCE_INT8_H_ +#define NNACL_INT8_REDUCE_INT8_H_ + +#include "nnacl_c/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanHW(int n, int plane, int count, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg, + int32_t bias); +int ReduceMeanHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); + +int ReduceMeanInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMeanLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMaxInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMaxLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMinInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMinLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceProdLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); 
+int ReduceProdInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumSquareLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumSquareInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif +#endif // NNACL_INT8_REDUCE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.c new file mode 100644 index 00000000..843283b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.c @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl_c/int8/relux_int8.h" + +void ReluXInt8(const int8_t *src, int length, int8_t *dst, const ReluXQuantArg *arg) { + for (int i = 0; i < length; ++i) { + if (src[i] <= arg->input_arg.zp_) { + dst[i] = arg->output_arg.zp_; + continue; + } + const int32_t input_val = src[i] - arg->input_arg.zp_; + const int32_t scaled_input = SaturatingRoundingDoublingHighMul(input_val, arg->input_multiplier_); + const int32_t shifted_input = RoundingDivideByPOT(scaled_input * (1U << arg->left_shift_), -arg->right_shift_); + const int32_t output = shifted_input + arg->output_arg.zp_; + dst[i] = (int8_t)MSMIN(output, arg->quantized_output_max); + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.h new file mode 100644 index 00000000..0676cf10 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/relux_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_INT8_RELU_INT8_H_
+#define NNACL_INT8_RELU_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/int8/quantize.h"
+
+typedef struct ReluXQuantArg {
+  QuantArg input_arg;
+  QuantArg output_arg;
+  int input_multiplier_;
+  int left_shift_;
+  int right_shift_;
+  int quantized_output_min;
+  int quantized_output_max;
+} ReluXQuantArg;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void ReluXInt8(const int8_t *src, int length, int8_t *dst, const ReluXQuantArg *arg);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_RELU_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.c
new file mode 100644
index 00000000..4b3fc200
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.c
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/reshape_int8.h"
+#include "nnacl_c/reshape_parameter.h"
+#include <string.h> /* memcpy */
+#include <math.h>   /* round */
+
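+/* Requantizing reshape: when input and output share scale and zero point, the bytes
+ * are copied through unchanged; otherwise each element is rescaled as
+ * round(x * s_in / s_out + bias) + zp_out and clamped to the activation bounds. */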
+void Int8Reshape(const int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count, ReshapeQuantArg para) {
+  if (para.in_args_.scale_ == para.out_args_.scale_ && para.in_args_.zp_ == para.out_args_.zp_) {
+    memcpy(output_ptr, input_ptr, real_dst_count);
+  } else {
+    const float output_inverse_scale = 1.f / para.out_args_.scale_;
+    float scale = para.in_args_.scale_ * output_inverse_scale;
+    float bias = -para.in_args_.zp_ * scale;
+    int32_t output_zp = para.out_args_.zp_;
+    for (int i = 0; i < real_dst_count; i++) {
+      int32_t output_tmp = round(input_ptr[i] * scale + bias) + output_zp;
+      if (output_tmp > para.output_activation_max_) {
+        output_ptr[i] = para.output_activation_max_;
+      } else if (output_tmp < para.output_activation_min_) {
+        output_ptr[i] = para.output_activation_min_;
+      } else {
+        output_ptr[i] = (int8_t)output_tmp;
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.h
new file mode 100644
index 00000000..46fb480c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/reshape_int8.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_RESHAPE_INT8_H_
+#define NNACL_INT8_RESHAPE_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/reshape_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void Int8Reshape(const int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count, ReshapeQuantArg para);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_RESHAPE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.c
new file mode 100644
index 00000000..4f7f07f8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.c
@@ -0,0 +1,233 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string.h> /* memcpy */
+#include "nnacl_c/int8/resize_int8.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/errorcode.h"
+
+int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w,
+                       int channel, int index, int count, ResizeQuantArg quant_arg) {
+  if (out_w == 0) {
+    return NNACL_ERRCODE_DIVISOR_ZERO;
+  }
+  int in_plane = in_h * in_w;
+  int out_plane = out_h * out_w;
+  for (int n = 0; n < batch; n++) {
+    const int8_t *in_b_ptr = input_ptr + n * in_plane * channel;
+    int8_t *out_b_ptr = output_ptr + n * out_plane * channel;
+    for (int t = 0; t < count; t++) {
+      int ori_out_h = (index + t) / out_w;
+      int ori_out_w = (index + t) % out_w;
+      int32_t x_lower_value = quant_arg.x_axis_lower_[ori_out_w];
+      int32_t x_upper_value = quant_arg.x_axis_upper_[ori_out_w];
+      int32_t y_lower_value = quant_arg.y_axis_lower_[ori_out_h];
+      int32_t y_upper_value = quant_arg.y_axis_upper_[ori_out_h];
+      int32_t weight_x = quant_arg.x_axis_index_[ori_out_w] - (1 << 10) * x_lower_value;
+      int32_t one_minus_weight_x = (1 << 10) - weight_x;
+      int32_t weight_y = quant_arg.y_axis_index_[ori_out_h] - (1 << 10) * y_lower_value;
+      int32_t one_minus_weight_y = (1 << 10) - weight_y;
+      int64_t left_bottom_coef = (int64_t)(one_minus_weight_x * one_minus_weight_y);
+      int64_t left_top_coef = (int64_t)(weight_y * one_minus_weight_x);
+      int64_t right_bottom_coef = (int64_t)(weight_x * one_minus_weight_y);
+      int64_t right_top_coef = (int64_t)(weight_x * weight_y);
+      int input_lb_index = (y_lower_value * in_w + x_lower_value) * channel;
+      int input_lt_index = (y_upper_value * in_w + x_lower_value) * channel;
+      int input_rb_index = (y_lower_value * in_w + x_upper_value) * channel;
+      int input_rt_index = (y_upper_value * in_w + x_upper_value) * channel;
+      int c = 0;
+      for (; c < channel; c++) {
+        int64_t out_left_bottom = left_bottom_coef * in_b_ptr[input_lb_index];
+        int64_t out_left_top = left_top_coef * in_b_ptr[input_lt_index];
+        int64_t out_right_bottom = right_bottom_coef * in_b_ptr[input_rb_index];
+        int64_t out_right_top = right_top_coef * in_b_ptr[input_rt_index];
+        int64_t out_value = out_left_bottom + out_left_top + out_right_bottom + out_right_top;
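+        /* Each bilinear weight carries 10 fractional bits, so their product carries
+         * 20; adding 1 << 19 rounds to nearest before narrowing back to int8. */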
+ out_b_ptr[0] = (int8_t)((out_value + (1 << 19)) / (1 << 20)); + input_lb_index++; + input_lt_index++; + input_rb_index++; + input_rt_index++; + out_b_ptr++; + } + } + } + return NNACL_OK; +} + +int ResizeBilinearWithFloatScaleInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, + int out_h, int out_w, int channel, int index, int count, + ResizeFloatScaleQuantArg quant_arg) { + if (out_w == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int in_plane = in_h * in_w; + int out_plane = out_h * out_w; + for (int n = 0; n < batch; n++) { + const int8_t *in_b_ptr = input_ptr + n * in_plane * channel; + int8_t *out_b_ptr = output_ptr + n * out_plane * channel; + for (int t = 0; t < count; t++) { + int ori_out_h = (index + t) / out_w; + int ori_out_w = (index + t) % out_w; + int32_t x_lower_value = quant_arg.x_axis_lower_[ori_out_w]; + int32_t x_upper_value = quant_arg.x_axis_upper_[ori_out_w]; + int32_t y_lower_value = quant_arg.y_axis_lower_[ori_out_h]; + int32_t y_upper_value = quant_arg.y_axis_upper_[ori_out_h]; + float weight_x = quant_arg.x_axis_index_[ori_out_w] - x_lower_value; + const float one_minus_weight_x = 1 - weight_x; + float weight_y = quant_arg.y_axis_index_[ori_out_h] - y_lower_value; + const float one_minus_weight_y = 1 - weight_y; + float left_bottom_coef = one_minus_weight_x * one_minus_weight_y; + float left_top_coef = weight_y * one_minus_weight_x; + float right_bottom_coef = weight_x * one_minus_weight_y; + float right_top_coef = weight_x * weight_y; + int input_lb_index = (y_lower_value * in_w + x_lower_value) * channel; + int input_lt_index = (y_upper_value * in_w + x_lower_value) * channel; + int input_rb_index = (y_lower_value * in_w + x_upper_value) * channel; + int input_rt_index = (y_upper_value * in_w + x_upper_value) * channel; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= channel - 4; c += 4) { + float32x4_t in_lb; + in_lb[0] = (float)in_b_ptr[input_lb_index]; + in_lb[1] = (float)in_b_ptr[input_lb_index + 1]; + in_lb[2] = (float)in_b_ptr[input_lb_index + 2]; + in_lb[3] = (float)in_b_ptr[input_lb_index + 3]; + float32x4_t out_left_bottom = vmulq_n_f32(in_lb, left_bottom_coef); + float32x4_t in_lt; + in_lt[0] = (float)in_b_ptr[input_lt_index]; + in_lt[1] = (float)in_b_ptr[input_lt_index + 1]; + in_lt[2] = (float)in_b_ptr[input_lt_index + 2]; + in_lt[3] = (float)in_b_ptr[input_lt_index + 3]; + float32x4_t out_left_top = vmulq_n_f32(in_lt, left_top_coef); + float32x4_t in_rb; + in_rb[0] = (float)in_b_ptr[input_rb_index]; + in_rb[1] = (float)in_b_ptr[input_rb_index + 1]; + in_rb[2] = (float)in_b_ptr[input_rb_index + 2]; + in_rb[3] = (float)in_b_ptr[input_rb_index + 3]; + float32x4_t out_right_bottom = vmulq_n_f32(in_rb, right_bottom_coef); + float32x4_t in_rt; + in_rt[0] = (float)in_b_ptr[input_rt_index]; + in_rt[1] = (float)in_b_ptr[input_rt_index + 1]; + in_rt[2] = (float)in_b_ptr[input_rt_index + 2]; + in_rt[3] = (float)in_b_ptr[input_rt_index + 3]; + float32x4_t out_right_top = vmulq_n_f32(in_rt, right_top_coef); + float32x4_t out_value1 = vaddq_f32(out_left_bottom, out_left_top); + float32x4_t out_value2 = vaddq_f32(out_right_top, out_right_bottom); + float32x4_t out_value = vaddq_f32(out_value1, out_value2); + out_b_ptr[0] = (int8_t)(out_value[0]); + out_b_ptr[1] = (int8_t)(out_value[1]); + out_b_ptr[2] = (int8_t)(out_value[2]); + out_b_ptr[3] = (int8_t)(out_value[3]); + input_lb_index += 4; + input_lt_index += 4; + input_rb_index += 4; + input_rt_index += 4; + out_b_ptr += 4; + } +#endif + for (; c < channel; c++) { + float 
out_left_bottom = left_bottom_coef * in_b_ptr[input_lb_index]; + float out_left_top = left_top_coef * in_b_ptr[input_lt_index]; + float out_right_bottom = right_bottom_coef * in_b_ptr[input_rb_index]; + float out_right_top = right_top_coef * in_b_ptr[input_rt_index]; + float out_value = out_left_bottom + out_left_top + out_right_bottom + out_right_top; + out_b_ptr[0] = (int8_t)(out_value); + input_lb_index++; + input_lt_index++; + input_rb_index++; + input_rt_index++; + out_b_ptr++; + } + } + } + return NNACL_OK; +} + +int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, int tid, int thread_num) { + int batch, y, x, c; + c = output_shape[3]; + int in_h, in_w, new_height, new_width; + in_h = input_shape[1]; + in_w = input_shape[2]; + new_height = output_shape[1]; + new_width = output_shape[2]; + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int input_y = 0; + ComputeNearestNeighborInt(y, in_h, new_height, align_corners, &input_y); + for (x = 0; x < output_shape[2]; x++) { + int input_x = 0; + ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x); + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(int8_t)); + } + } + } + + return NNACL_OK; +} + +void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners, + int32_t *nearest) { + if (new_size == 0) { + return; + } + *nearest = (in_size * pos) / new_size; + if (align_corners && new_size != 1) { + *nearest = ((in_size - 1) * pos + (new_size - 1) / 2) / (new_size - 1); + } + *nearest = *nearest < in_size ? *nearest : in_size - 1; +} + +int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, const QuantMulArg *multiplier, + const QuantArg *quant_in, const QuantArg *quant_out, int tid, int thread_num) { + const int base_offset = 20; + int32_t batch, y, x, c; + int32_t in_h, in_w, new_height, new_width; + in_h = input_shape[1]; + in_w = input_shape[2]; + new_height = output_shape[1]; + new_width = output_shape[2]; + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int input_y = 0; + ComputeNearestNeighborInt(y, in_h, new_height, align_corners, &input_y); + for (x = 0; x < output_shape[2]; x++) { + int input_x = 0; + ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x); + for (c = 0; c < output_shape[3]; c++) { + int in_offset = Offset(input_shape, batch, input_y, input_x, c); + int out_offset = Offset(output_shape, batch, y, x, c); + + int32_t out_value = MultiplyByQuantizedMultiplier( + input_data[in_offset] - quant_in->zp_, multiplier->multiplier_, + multiplier->left_shift_ + base_offset, multiplier->right_shift_ - base_offset) + + quant_out->zp_; + out_value = out_value > INT8_MAX ? INT8_MAX : out_value; + out_value = out_value < INT8_MIN ? 
INT8_MIN : out_value;
+          output_data[out_offset] = (int8_t)out_value;
+        }
+      }
+    }
+  }
+
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.h
new file mode 100644
index 00000000..78e1972b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/resize_int8.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_RESIZE_H_
+#define NNACL_INT8_RESIZE_H_
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/resize_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w,
+                       int channel, int index, int count, ResizeQuantArg quant_arg);
+
+int ResizeBilinearWithFloatScaleInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w,
+                                     int out_h, int out_w, int channel, int index, int count,
+                                     ResizeFloatScaleQuantArg quant_arg);
+
+int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape,
+                                    const int32_t *output_shape, const bool align_corners, int tid, int thread_num);
+
+int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape,
+                              const int32_t *output_shape, const bool align_corners, const QuantMulArg *multiplier,
+                              const QuantArg *quant_in, const QuantArg *quant_out, int tid, int thread_num);
+
+void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners,
+                               int32_t *nearest);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_RESIZE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.c
new file mode 100644
index 00000000..23a508c5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.c
@@ -0,0 +1,164 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/scale_int8.h"
+#include "nnacl_c/int8/fixed_point.h"
+
+#ifdef ENABLE_NEON
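+/* NEON helpers: multiply two zero-point-adjusted int32x4 lane groups, requantize via
+ * the multiplier/shift pair, add the output zero point, clamp to the activation
+ * range, and narrow to int16x4; the Mul3 variant also folds in a requantized offset. */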
+ */ + +#include "nnacl_c/int8/scale_int8.h" +#include "nnacl_c/int8/fixed_point.h" + +#ifdef ENABLE_NEON +int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, const ScaleQuantParameter *scale_param) { + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + scale_param->scale_mul_arg_.right_shift_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +int16x4_t ClacSumHalfWordMul3(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t scaled_input2, + const ScaleQuantParameter *scale_param) { + int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_); + int32x4_t output_multiplier_vec2 = vdupq_n_s32(scale_param->offset_mul_arg_.multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)(scale_param->scale_mul_arg_.left_shift_)); + int32x4_t left_shift_out_vec2 = vdupq_n_s32(1 << (size_t)(scale_param->offset_mul_arg_.left_shift_)); + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + scale_param->scale_mul_arg_.right_shift_); + int32x4_t raw_sum2 = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(scaled_input2, left_shift_out_vec2), output_multiplier_vec2), + scale_param->offset_mul_arg_.right_shift_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_)); + raw_sum = vaddq_s32(raw_sum, raw_sum2); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_)); + return vqmovn_s32(raw_sum); +} +#endif + +void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleQuantParameter *scale_param, + int real_dst_count) { + int index = 0; +#ifdef ENABLE_NEON + int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_); + + for (; index <= real_dst_count - 8; index += 8) { + int8x8_t input_s8 = vld1_s8(in_data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_)); + + int8x8_t input1_s8 = vld1_s8(scale + index); + int16x8_t input1_s16 = vmovl_s8(input1_s8); + int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_)); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int16x4_t sum_low = + ClacSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param); + int16x4_t sum_high = + ClacSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = 
+void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleQuantParameter *scale_param,
+                 int real_dst_count) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
+  int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
+
+  for (; index <= real_dst_count - 8; index += 8) {
+    int8x8_t input_s8 = vld1_s8(in_data + index);
+    int16x8_t input_s16 = vmovl_s8(input_s8);
+    int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
+
+    int8x8_t input1_s8 = vld1_s8(scale + index);
+    int16x8_t input1_s16 = vmovl_s8(input1_s8);
+    int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
+
+    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
+    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
+    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
+    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
+
+    int16x4_t sum_low =
+      ClacSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param);
+    int16x4_t sum_high =
+      ClacSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param);
+
+    int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
+    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
+    /* store at the absolute index so the scalar tail below stays aligned */
+    vst1_s8(out_data + index, res_u8_n0);
+  }
+#endif
+  for (; index < real_dst_count; ++index) {
+    const int32_t input0_val = scale_param->input_zp_ + in_data[index];
+    const int32_t input1_val = scale_param->scale_zp_ + scale[index];
+    int32_t mul_result = RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
+                                        scale_param->scale_mul_arg_.multiplier_),
+      scale_param->scale_mul_arg_.right_shift_);
+
+    mul_result += scale_param->output_zp_;
+
+    if (mul_result > scale_param->output_activation_max_) {
+      out_data[index] = scale_param->output_activation_max_;
+    } else if (mul_result < scale_param->output_activation_min_) {
+      out_data[index] = scale_param->output_activation_min_;
+    } else {
+      out_data[index] = (int8_t)mul_result;
+    }
+  }
+  return;
+}
+
+void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
+                         const ScaleQuantParameter *scale_param, int real_dst_count) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  for (; index <= real_dst_count - 8; index += 8) {
+    int8x8_t input_s8 = vld1_s8(in_data + index);
+    int16x8_t input_s16 = vmovl_s8(input_s8);
+    int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
+
+    int8x8_t input1_s8 = vld1_s8(scale + index);
+    int16x8_t input1_s16 = vmovl_s8(input1_s8);
+    int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
+
+    int8x8_t input2_s8 = vld1_s8(offset + index);
+    int16x8_t input2_s16 = vmovl_s8(input2_s8);
+    int16x8_t input2_val = vaddq_s16(input2_s16, vdupq_n_s16(scale_param->offset_zp_));
+
+    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
+    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
+    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
+    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
+    int32x4_t input2_low = vmovl_s16(vget_low_s16(input2_val));
+    int32x4_t input2_high = vmovl_s16(vget_high_s16(input2_val));
+
+    int16x4_t sum_low = ClacSumHalfWordMul3(input0_low, input1_low, input2_low, scale_param);
+    int16x4_t sum_high = ClacSumHalfWordMul3(input0_high, input1_high, input2_high, scale_param);
+
+    int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
+    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
+    /* store at the absolute index so the scalar tail below stays aligned */
+    vst1_s8(out_data + index, res_u8_n0);
+  }
+#endif
+  for (; index < real_dst_count; ++index) {
+    const int32_t input0_val = in_data[index] - scale_param->input_zp_;
+    const int32_t input1_val = scale[index] - scale_param->scale_zp_;
+    int32_t mul_result = RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
+                                        scale_param->scale_mul_arg_.multiplier_),
+      scale_param->scale_mul_arg_.right_shift_);
+    int tmp_bias = offset[index] - scale_param->offset_zp_;
+    int bias = RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
+                                        scale_param->offset_mul_arg_.multiplier_),
+      scale_param->offset_mul_arg_.right_shift_);
+
+    mul_result += bias + scale_param->output_zp_;
+
+    if (mul_result > scale_param->output_activation_max_) {
+      out_data[index] = scale_param->output_activation_max_;
+    } else if (mul_result < scale_param->output_activation_min_) {
+      out_data[index] = scale_param->output_activation_min_;
+    } else {
+      out_data[index] = (int8_t)mul_result;
+    }
+  }
+  return;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.h
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.h new file mode 100644 index 00000000..02cace65 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/scale_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SCALE_INT8_H_ +#define NNACL_SCALE_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleQuantParameter *scale_param, + int real_dst_count); +void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset, + const ScaleQuantParameter *scale_param, int real_dst_count); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_SCALE_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.c new file mode 100644 index 00000000..6f892602 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.c @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/int8/sigmoid_int8.h" + +int SigmoidInt8(const int8_t *src, int length, int8_t *dst, int8_t *table) { + for (int i = 0; i < length; i++) { + const int8_t input_value = src[i]; + uint8_t index = (uint8_t)input_value; + dst[i] = table[index]; + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.h new file mode 100644 index 00000000..5fc8db7b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sigmoid_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_INT8_SIGMOID_INT8_H_
+#define NNACL_INT8_SIGMOID_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/int8/fixed_point.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SigmoidInt8(const int8_t *src, int length, int8_t *dst, int8_t *table);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_SIGMOID_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.c
new file mode 100644
index 00000000..5ae5a4f5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.c
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/slice_int8.h"
+#include <float.h>  /* FLT_EPSILON */
+#include <string.h> /* memcpy */
+#include <math.h>   /* fabs */
+#include "nnacl_c/errorcode.h"
+
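+/* Strided 8-D slice: worker threads split dimension 5; when input and output quant
+ * params match, each innermost row is memcpy'd, otherwise every value is requantized
+ * through MultiplyByQuantizedMultiplier and clamped to the activation range. */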
+int SliceInt8(const int8_t *input, int8_t *output, const SliceStruct *param, const SliceQuantArg *quant_arg,
+              int thread_id, int thread_num) {
+  double input_scale = quant_arg->in_args_.scale_;
+  int input_zp = quant_arg->in_args_.zp_;
+  double output_scale = quant_arg->out_args_.scale_;
+  int output_zp = quant_arg->out_args_.zp_;
+  const int base_offset = 20;
+  int act_min = quant_arg->output_activation_min_;
+  int act_max = quant_arg->output_activation_max_;
+
+  size_t out_stride[8];
+  out_stride[7] = 1;
+  for (int i = 6; i >= 0; --i) {
+    out_stride[i] = out_stride[i + 1] * param->size_[i + 1];
+  }
+
+  int count_per_thread = UP_DIV(param->size_[5], thread_num);
+  size_t thread_begin = thread_id * count_per_thread;
+  size_t thread_end = MSMIN(param->size_[5], thread_begin + count_per_thread);
+  int unit_size = param->size_[7] * sizeof(int8_t);
+  size_t in_stride[8];
+  in_stride[7] = 1;
+  for (int i = 6; i >= 0; --i) {
+    in_stride[i] = param->shape_[i + 1] * in_stride[i + 1];
+  }
+  int i, j, k, l, n, h, w, c;
+
+  int equal_quant = 0;
+  if (fabs(input_scale - output_scale) <= FLT_EPSILON && input_zp == output_zp) {
+    equal_quant = 1;
+  }
+
+  for (i = 0; i < param->size_[0]; ++i) {
+    size_t out_offset0 = i * out_stride[0];
+    size_t in_offset0 = (i + param->begin_[0]) * in_stride[0] + param->begin_[7];
+    for (j = 0; j < param->size_[1]; ++j) {
+      size_t out_offset1 = j * out_stride[1] + out_offset0;
+      size_t in_offset1 = (j + param->begin_[1]) * in_stride[1] + in_offset0;
+      for (k = 0; k < param->size_[2]; ++k) {
+        size_t out_offset2 = k * out_stride[2] + out_offset1;
+        size_t in_offset2 = (k + param->begin_[2]) * in_stride[2] + in_offset1;
+        for (l = 0; l < param->size_[3]; ++l) {
+          size_t out_offset3 = l * out_stride[3] + out_offset2;
+          size_t in_offset3 = (l + param->begin_[3]) * in_stride[3] + in_offset2;
+          for (n = 0; n < param->size_[4]; ++n) {
+            size_t out_offset4 = n * out_stride[4] + out_offset3;
+            size_t in_offset4 = (n + param->begin_[4]) * in_stride[4] + in_offset3;
+            for (h = thread_begin; h < thread_end; ++h) {
+              size_t out_offset5 = h * out_stride[5] + out_offset4;
+              size_t in_offset5 = (h + param->begin_[5]) * in_stride[5] + in_offset4;
+              for (w = 0; w < param->size_[6]; ++w) {
+                size_t out_offset = w * out_stride[6] + out_offset5;
+                size_t in_offset = (w + param->begin_[6]) * in_stride[6] + in_offset5;
+                if (equal_quant == 1) {
+                  memcpy(output + out_offset, input + in_offset, unit_size);
+                } else {
+                  for (c = 0; c < param->size_[7]; ++c) {
+                    int32_t output_val =
+                      MultiplyByQuantizedMultiplier(input[in_offset + c] - input_zp, quant_arg->multiplier_.multiplier_,
+                                                    quant_arg->multiplier_.left_shift_ + base_offset,
+                                                    quant_arg->multiplier_.right_shift_ - base_offset) +
+                      output_zp;
+                    output_val = MSMAX(INT8_MIN, MSMIN(output_val, INT8_MAX));
+                    output[c + out_offset] = (int8_t)MSMAX(act_min, MSMIN(output_val, act_max));
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return 0;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.h
new file mode 100644
index 00000000..7ace3db0
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/slice_int8.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INT8_SLICE_INT8_H_
+#define NNACL_INT8_SLICE_INT8_H_
+
+#include <stdint.h>
+#include <string.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/kernel/slice.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SliceInt8(const int8_t *input, int8_t *output, const SliceStruct *param, const SliceQuantArg *quant_arg,
+              int thread_id, int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_SLICE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.c
new file mode 100644
index 00000000..7d32f06a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.c
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/int8/softmax_int8.h" +#include "nnacl_c/errorcode.h" + +int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int32_t *exp_data, int32_t *sum_data, + const int32_t *input_shape, int n_dim, int32_t axis, const SoftmaxQuantArg *quant_param) { + int axis_shape_size = input_shape[axis]; + int inner_size = 1; + if (n_dim > DIMENSION_5D) { + return NNACL_ERR; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int o = 0; o < count; o++) { + int outter_offset = o * axis_shape_size * inner_size; + + for (int c = 0; c < inner_size; c++) { + int8_t max_row = quant_param->output_activation_min_; + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + c + i * inner_size; + max_row = MSMAX(max_row, input_ptr[axis_offset]); + } + + int32_t exp_sum = 0; + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + c + i * inner_size; + const int32_t input_val = input_ptr[axis_offset] - max_row; + const int32_t input_scaled = SaturatingRoundingDoublingHighMul( + input_val * (1 << (unsigned int)quant_param->shift_left_), quant_param->output_multiplier_); + int exp_val = exp_on_negative_values(input_scaled, 5); + exp_data[axis_offset] = exp_val; + exp_sum = exp_sum + Rescale(exp_val, 0, 12); + } + sum_data[c] = exp_sum; + } + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + i * inner_size; + for (int c = 0; c < inner_size; ++c) { + int num_bits_over_unit; + int shifted_scale = ComputerReciprocal(sum_data[c], 12, &num_bits_over_unit); + int unsat_output = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8); + + int raw_output = unsat_output + quant_param->output_activation_min_; + output_ptr[axis_offset + c] = + (int8_t)MSMAX(quant_param->output_activation_min_, MSMIN(raw_output, quant_param->output_activation_max_)); + } + } + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.h new file mode 100644 index 00000000..6f538af2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/softmax_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_SOFTMAX_INT8_H_
+#define NNACL_INT8_SOFTMAX_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/softmax_parameter.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int32_t *exp_data, int32_t *sum_data,
+                const int32_t *input_shape, int n_dim, int32_t axis, const SoftmaxQuantArg *quant_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_SOFTMAX_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.c
new file mode 100644
index 00000000..f5bba63b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.c
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl_c/int8/space_to_batch_int8.h"
+#include "nnacl_c/common_func.h"
+#include <string.h> /* memcpy, memset */
+
+void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int32_t *block_sizes, const int32_t *in_shape,
+                            const int32_t *out_shape) {
+  int out_dim0 = out_shape[0];
+  int out_dim1 = out_shape[1];
+  int out_dim2 = out_shape[2];
+  int copy_num = out_shape[3];
+  int block_w = block_sizes[1];
+  int block_h = block_sizes[0];
+  int in_strides[4] = {0};
+  ComputeStrides(in_shape, in_strides, 4);
+  int out_strides[4] = {0};
+  ComputeStrides(out_shape, out_strides, 4);
+  size_t copy_size = copy_num * sizeof(int8_t);
+  size_t out_offset = 0;
+
+  NNACL_CHECK_ZERO_RETURN(in_shape[0]);
+  NNACL_CHECK_ZERO_RETURN(block_w);
+  for (int n = 0; n < out_dim0; ++n) {
+    int in_n = n % in_shape[0];
+    int32_t stride_w = (n / in_shape[0]) % block_w;
+    int32_t stride_h = (n / in_shape[0]) / block_w;
+    size_t in_offset0 = in_n * in_strides[0];
+    for (int h = 0; h < out_dim1; ++h) {
+      size_t in_offset1 = in_offset0 + (h * block_h + stride_h) * in_strides[1];
+      for (int w = 0; w < out_dim2; ++w) {
+        size_t in_offset2 = in_offset1 + (w * block_w + stride_w) * in_strides[2];
+        memcpy(output + out_offset, input + in_offset2, copy_size);
+        out_offset += copy_num;
+      }
+    }
+  }
+}
+
+void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp) {
+  int block_shape_h = param->block_sizes_[0];
+  int block_shape_w = param->m_ == 2 ? param->block_sizes_[1] : 1;
+  int in_b = param->input_shape_[0];
+  int in_h = param->input_shape_[1];
+  int in_w = param->input_shape_[2];
+  int channel = param->input_shape_[3];
+  int out_h = param->output_shape_[1];
+  int out_w = param->output_shape_[2];
+  int pad_t = param->paddings_[0];
+  int pad_l = param->m_ == 2 ?
param->paddings_[2] : 0; + + NNACL_CHECK_ZERO_RETURN(in_b); + NNACL_CHECK_ZERO_RETURN(block_shape_w); + for (int i = 0; i < param->output_shape_[0]; ++i) { + int in_batch = i % in_b; + int offset_w = (i / in_b) % block_shape_w; + int offset_h = (i / in_b) / block_shape_w; + int in_b_offset = in_batch * in_h * in_w * channel; + int out_b_offset = i * out_h * out_w * channel; + for (int j = 0; j < out_h; ++j) { + int out_h_offset = out_b_offset + j * out_w * channel; + for (int k = 0; k < out_w; ++k) { + int8_t *out_ptr = output + out_h_offset + k * channel; + int index_h = j * block_shape_h + offset_h; + int index_w = k * block_shape_w + offset_w; + if (index_h < pad_t || index_h >= (pad_t + in_h) || index_w < pad_l || index_w >= (pad_l + in_w)) { + memset(out_ptr, zp, channel * sizeof(int8_t)); + } else { + int in_plane_offset = in_b_offset + ((index_h - pad_t) * in_w + (index_w - pad_l)) * channel; + const int8_t *in_ptr = input + in_plane_offset; + memcpy(out_ptr, in_ptr, channel * sizeof(int8_t)); + } + } + } + } +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.h new file mode 100644 index 00000000..8d60bc1c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/space_to_batch_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_SPACE_TO_BATCH_INT8_H_ +#define NNACL_INT8_SPACE_TO_BATCH_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int32_t *block_sizes, const int32_t *in_shape, + const int32_t *out_shape); +void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SPACE_TO_BATCH_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.c new file mode 100644 index 00000000..5cc53dfa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.c @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/split_int8.h" +#include +#include +#include +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/errorcode.h" + +int Int8DoSplit(const int8_t *in_data, int8_t **out_data, const int32_t *input_shape, int offset, int num_unit, + const SplitParameter *param) { + if (in_data == NULL || out_data == NULL) { + return NNACL_ERR; + } + const int num_split = param->num_split_; + const int32_t *split_sizes = param->split_sizes_; + const int32_t *strides = param->strides_; + const int split_dim = param->split_dim_; + int in_stride = strides[split_dim]; + + int stride_per_split = in_stride * input_shape[split_dim]; + int split_which = offset % num_split; + int split_times = offset / num_split; + const int8_t *src = in_data + split_times * stride_per_split; + for (int i = 0; i < split_which; i++) { + src += split_sizes[i] * in_stride; + } + + const QuantArg in_quant_arg = param->quant_arg_.in_args_; + float in_scale = in_quant_arg.scale_; + int32_t in_zp = in_quant_arg.zp_; + const QuantArg *out_quant_arg = param->quant_arg_.out_args_; + + for (int i = offset; i < offset + num_unit; i++) { + split_which = i % num_split; + split_times = i / num_split; + int copy_size = split_sizes[split_which] * in_stride; + int8_t *dst = out_data[split_which] + split_times * copy_size; + float out_scale = out_quant_arg[split_which].scale_; + int32_t out_zp = out_quant_arg[split_which].zp_; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + (void)memcpy(dst, src, copy_size * sizeof(int8_t)); + } else { + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + for (int j = 0; j < copy_size; j++) { + int32_t output_tmp = round(src[j] * scale + bias) + out_zp; + if (output_tmp > param->quant_arg_.output_activation_max_) { + dst[j] = param->quant_arg_.output_activation_max_; + } else if (output_tmp < param->quant_arg_.output_activation_min_) { + dst[j] = param->quant_arg_.output_activation_min_; + } else { + dst[j] = (int8_t)output_tmp; + } + } + } + src += copy_size; + } + + return NNACL_OK; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h new file mode 100644 index 00000000..8db63feb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+int Int8DoSplit(const int8_t *in_data, int8_t **out_data, const int32_t *input_shape, int offset, int num_unit,
+                const SplitParameter *param) {
+  if (in_data == NULL || out_data == NULL) {
+    return NNACL_ERR;
+  }
+  const int num_split = param->num_split_;
+  const int32_t *split_sizes = param->split_sizes_;
+  const int32_t *strides = param->strides_;
+  const int split_dim = param->split_dim_;
+  int in_stride = strides[split_dim];
+
+  int stride_per_split = in_stride * input_shape[split_dim];
+  int split_which = offset % num_split;
+  int split_times = offset / num_split;
+  const int8_t *src = in_data + split_times * stride_per_split;
+  for (int i = 0; i < split_which; i++) {
+    src += split_sizes[i] * in_stride;
+  }
+
+  const QuantArg in_quant_arg = param->quant_arg_.in_args_;
+  float in_scale = in_quant_arg.scale_;
+  int32_t in_zp = in_quant_arg.zp_;
+  const QuantArg *out_quant_arg = param->quant_arg_.out_args_;
+
+  for (int i = offset; i < offset + num_unit; i++) {
+    split_which = i % num_split;
+    split_times = i / num_split;
+    int copy_size = split_sizes[split_which] * in_stride;
+    int8_t *dst = out_data[split_which] + split_times * copy_size;
+    float out_scale = out_quant_arg[split_which].scale_;
+    int32_t out_zp = out_quant_arg[split_which].zp_;
+    if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) {
+      (void)memcpy(dst, src, copy_size * sizeof(int8_t));
+    } else {
+      float scale = in_scale / out_scale;
+      float bias = -in_zp * scale;
+      for (int j = 0; j < copy_size; j++) {
+        int32_t output_tmp = round(src[j] * scale + bias) + out_zp;
+        if (output_tmp > param->quant_arg_.output_activation_max_) {
+          dst[j] = param->quant_arg_.output_activation_max_;
+        } else if (output_tmp < param->quant_arg_.output_activation_min_) {
+          dst[j] = param->quant_arg_.output_activation_min_;
+        } else {
+          dst[j] = (int8_t)output_tmp;
+        }
+      }
+    }
+    src += copy_size;
+  }
+
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h
new file mode 100644
index 00000000..8db63feb
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/split_int8.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_SPLIT_INT8_H_
+#define NNACL_INT8_SPLIT_INT8_H_
+
+#include <stdint.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/split_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int Int8DoSplit(const int8_t *in_data, int8_t **out_data, const int32_t *input_shape, int offset, int num_unit,
+                const SplitParameter *split_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_SPLIT_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.c
new file mode 100644
index 00000000..19b4b914
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.c
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/squeeze_int8.h"
+#include <math.h> /* round */
+
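+/* Squeeze only drops size-1 dimensions, so the int8 kernel is a pure requantization
+ * copy; each worker handles every thread_count-th element starting at task_id. */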
+void SqueezeInt8(const int8_t *input_ptr, int8_t *output_ptr, const SqueezeQuantArg *quant_Squeeze_parm, int num,
+                 int task_id, int thread_count) {
+  float output_scale = quant_Squeeze_parm->out_quant_args_->scale_;
+  const float output_inverse_scale = 1.f / output_scale;
+  QuantArg *input_quant = quant_Squeeze_parm->in_quant_args_;
+  int output_zp = quant_Squeeze_parm->out_quant_args_->zp_;
+
+  const int i = 0;
+  for (int j = task_id; j < num; j += thread_count) {
+    float scale = input_quant[i].scale_ * output_inverse_scale;
+    float bias = -input_quant[i].zp_ * scale;
+    int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp;
+    if (output_tmp > 127) {
+      output_ptr[j] = 127;
+    } else if (output_tmp < -128) {
+      output_ptr[j] = -128;
+    } else {
+      output_ptr[j] = (int8_t)output_tmp;
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.h
new file mode 100644
index 00000000..4129a292
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/squeeze_int8.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_SQUEEZE_INT8_H_
+#define NNACL_INT8_SQUEEZE_INT8_H_
+
+#include "nnacl_c/squeeze_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void SqueezeInt8(const int8_t *input_ptr, int8_t *output_ptr, const SqueezeQuantArg *quant_Squeeze_parm, int num,
+                 int task_id, int thread_count);
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NNACL_INT8_SQUEEZE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.c
new file mode 100644
index 00000000..f1548537
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.c
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/sub_int8.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#include "nnacl_c/int8/common_func_int8.h"
+#endif
+#include "nnacl_c/int8/fixed_point.h"
+
+#ifdef ENABLE_NEON
+
+int16x4_t DoClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
+                            int32x4_t output_multiplier_vec, const SubQuantArg *para) {
+  int32x4_t raw_data = vsubq_s32(scaled_input0, scaled_input1);
+
+  raw_data = RoundingDivideByPOTInt32x4(vqrdmulhq_s32(vmulq_s32(raw_data, left_shift_out_vec), output_multiplier_vec),
+                                        para->right_shift_out_);
+  raw_data = vaddq_s32(raw_data, vdupq_n_s32(para->out_args_.zp_));
+  raw_data = vmaxq_s32(raw_data, vdupq_n_s32(para->output_activation_min_));
+  raw_data = vminq_s32(raw_data, vdupq_n_s32(para->output_activation_max_));
+  return vqmovn_s32(raw_data);
+}
+
+void SubInt8NEON(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count,
+                 const SubQuantArg *para, int32_t *index) {
+  int32x4_t left_shift_result0_vec = vdupq_n_s32(para->left_shift_result0_);
+  int32x4_t left_shift_result1_vec = vdupq_n_s32(para->left_shift_result1_);
+  int32x4_t input0_multiplier_vec = vdupq_n_s32(para->input0_multiplier_);
+  int32x4_t input1_multiplier_vec = vdupq_n_s32(para->input1_multiplier_);
+  int32x4_t output_multiplier_vec = vdupq_n_s32(para->output_multiplier_);
+  int32x4_t left_shift_out_vec = vdupq_n_s32((1 << (size_t)para->left_shift_out_));
+  int32x4_t right_shift0_vec = vdupq_n_s32(-para->right_shift0_);
+  int32x4_t right_shift1_vec = vdupq_n_s32(-para->right_shift1_);
+
+  for (; (*index) <= real_dst_count - 8; (*index) += 8) {
+    int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para->in0_args_.zp_);
+    int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para->in1_args_.zp_);
+
+    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
+    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
+    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
+    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
+
+    int32x4_t scaled_input0_low =
+      ClacScaledInput(input0_low, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec);
+    int32x4_t scaled_input0_high =
+ ClacScaledInput(input0_high, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input1_low = + ClacScaledInput(input1_low, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + int32x4_t scaled_input1_high = + ClacScaledInput(input1_high, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + + int16x4_t sum_low = + DoClacSumHalfWord(scaled_input0_low, scaled_input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + DoClacSumHalfWord(scaled_input0_high, scaled_input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data + *index, res_u8_n0); + } +} +#endif + +int SubInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const SubQuantArg *para) { + int index = 0; +#ifdef ENABLE_NEON + SubInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + const int32_t shifted_input0_val = input0_val * para->left_shift_result0_; + const int32_t shifted_input1_val = input1_val * para->left_shift_result1_; + const int32_t scaled_input0_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input0_val, para->input0_multiplier_), para->right_shift0_); + const int32_t scaled_input1_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input1_val, para->input1_multiplier_), para->right_shift1_); + + const int32_t raw_data = scaled_input0_val - scaled_input1_val; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data * (1 << (unsigned int)para->left_shift_out_), + para->output_multiplier_), + para->right_shift_out_) + + para->out_args_.zp_; + + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return 0; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.h new file mode 100644 index 00000000..e9107daa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/sub_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_INT8_SUB_INT8_H_
+#define NNACL_INT8_SUB_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int SubInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count,
+            const SubQuantArg *para);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_SUB_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.c
new file mode 100644
index 00000000..35171847
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.c
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/tanh_int8.h"
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+
+void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant) {
+  for (int i = 0; i < size; ++i) {
+    float fp32_src = (input_ptr[i] - quant->in_zp_) * quant->in_scale_;
+    float fp32_dst = TanhOpt(fp32_src);
+    int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_);
+    output_ptr[i] = (int8_t)MSMAX(MSMIN(int32_dst, 127), -128);
+  }
+  return;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.h
new file mode 100644
index 00000000..332d7180
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/tanh_int8.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_TANH_INT8_H_
+#define NNACL_INT8_TANH_INT8_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/fixed_point.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+
+typedef struct TanhQuantParameter {
+  int32_t in_zp_;
+  int32_t out_zp_;
+  double in_scale_;
+  double out_scale_;
+} TanhQuantParameter;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_TANH_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.c
new file mode 100644
index 00000000..643ff0c1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.c
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/topk_int8.h"
+
+int DescendCmpInt8(const void *a, const void *b) {
+  return ((const TopkNodeInt8 *)b)->element - ((const TopkNodeInt8 *)a)->element;
+}
+
+int AscendCmpInt8(const void *a, const void *b) {
+  return ((const TopkNodeInt8 *)a)->element - ((const TopkNodeInt8 *)b)->element;
+}
+
+void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter) {
+  int dim_size = parameter->dim_size_;
+  int outer_loop_num = parameter->outer_loop_num_;
+  int inner_loop_num = parameter->inner_loop_num_;
+  int k = parameter->k_;
+  TopkNodeInt8 *top_map = (TopkNodeInt8 *)parameter->topk_node_list_;
+
+  int8_t *cur_input_data = (int8_t *)input_data;
+  int8_t *cur_output_data = (int8_t *)output_data;
+  int32_t *cur_output_index = output_index;
+  for (int i = 0; i < outer_loop_num; i++) {
+    int in_offset = i * dim_size * inner_loop_num;
+    int out_offset = i * k * inner_loop_num;
+    for (int j = 0; j < inner_loop_num; j++) {
+      for (int m = 0; m < dim_size; m++) {
+        int offset = in_offset + m * inner_loop_num + j;
+        top_map[m].element = *(cur_input_data + offset);
+        top_map[m].index = m;
+      }
+      qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmpInt8);
+      if (!parameter->sorted_) {
+        qsort(top_map, k, sizeof(top_map[0]), AscendCmpInt8);
+      }
+      for (int m = 0; m < k; m++) {
+        int offset = out_offset + m * inner_loop_num + j;
+        cur_output_data[offset] = top_map[m].element;
+        cur_output_index[offset] = top_map[m].index;
+      }
+    }
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.h
new file mode 100644
index 00000000..9910f355
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/topk_int8.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_TOPK_INT8_H_ +#define NNACL_INT8_TOPK_INT8_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/topk_fp32.h" + +typedef struct TopkNodeInt8 { + int8_t element; + int32_t index; +} TopkNodeInt8; + +#ifdef __cplusplus +extern "C" { +#endif +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_TOPK_INT8_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.c new file mode 100644 index 00000000..65c1a590 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.c @@ -0,0 +1,257 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/int8/transpose_int8.h" +void TransposeDim2Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } + return; +} + +void TransposeDim3Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void TransposeDim5Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int 
k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void TransposeDim6Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_n = n * out_stride4; + int stride4_n = n * stride4; + for (int p = 0; p < output5; ++p) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + p] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + p * stride5]; + } + } + } + } + } + } +} + +int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape, + const TransposeParameter *transpose_param) { + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + NNACL_CHECK_NULL_RETURN_ERR(transpose_param); + + const int32_t *perm = transpose_param->perm_; + const int32_t *strides = transpose_param->strides_; + const int32_t *out_strides = transpose_param->out_strides_; + const int num_axes = transpose_param->num_axes_; + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; i++) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, transpose_param->data_num_ * sizeof(int8_t)); + return NNACL_OK; + } + + switch (num_axes) { + case 2: + TransposeDim2Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 3: + TransposeDim3Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 4: + TransposeDim4Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 5: + TransposeDim5Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 6: + TransposeDim6Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + default: + return NNACL_ERR; + } + + return NNACL_OK; +} + +void TransposeDimsInt8(const int8_t *in_data, int8_t 
*out_data, const int32_t *output_shape,
+                       const TransposeParameter *transpose_param, int task_id, int thread_num) {
+  NNACL_CHECK_NULL_RETURN_VOID(in_data);
+  NNACL_CHECK_NULL_RETURN_VOID(out_data);
+  NNACL_CHECK_NULL_RETURN_VOID(output_shape);
+  NNACL_CHECK_NULL_RETURN_VOID(transpose_param);
+  NNACL_CHECK_ZERO_RETURN(thread_num);
+  const int32_t *perm = transpose_param->perm_;
+  const int32_t *strides = transpose_param->strides_;
+  const int32_t *out_strides = transpose_param->out_strides_;
+  int num_axes = transpose_param->num_axes_;
+  size_t data_size = (size_t)((*out_strides) * output_shape[0]);
+  size_t offset_size = UP_DIV(data_size, thread_num);
+  size_t task_offset = offset_size * task_id;
+  size_t count = data_size - task_offset;
+  if (data_size < task_offset) {
+    return;
+  }
+  count = MSMIN(offset_size, count);
+  for (size_t idx = task_offset; idx < task_offset + count; ++idx) {
+    int pos = (int)idx;
+    int output_idx = 0;
+    int input_idx = 0;
+    for (int i = 0; i < num_axes; ++i) {
+      NNACL_CHECK_ZERO_RETURN(*(out_strides + i));
+      int position = pos / *(out_strides + i);
+      int out_stride = i < num_axes - 1 ? out_strides[i] : 1;
+      output_idx += (position * out_stride);
+      input_idx += (position * strides[perm[i]]);
+      pos -= position * (*(out_strides + i));
+    }
+    out_data[output_idx] = in_data[input_idx];
+  }
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.h
new file mode 100644
index 00000000..50fbafee
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/transpose_int8.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_TRANSPOSE_INT8_H_
+#define NNACL_INT8_TRANSPOSE_INT8_H_
+
+#include <string.h>
+#include "nnacl_c/transpose_parameter.h"
+#include "nnacl_c/errorcode.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape,
+                    const TransposeParameter *transpose_param);
+void TransposeDimsInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape,
+                       const TransposeParameter *transpose_param, int task_id, int thread_num);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_TRANSPOSE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.c
new file mode 100644
index 00000000..44d10ed9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.c
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/int8/unsqueeze_int8.h"
+#include "nnacl_c/unsqueeze_parameter.h"
+#include "nnacl_c/errorcode.h"
+
+int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, const UnSqueezeParameter *para_, size_t data_size,
+                  int task_id) {
+  float output_scale = para_->quant_arg.out_quant_args_.scale_;
+  NNACL_CHECK_ZERO_RETURN_ERR(output_scale);
+  int8_t output_zp = para_->quant_arg.out_quant_args_.zp_;
+  float input_scale = para_->quant_arg.in_quant_args_.scale_;
+  int8_t input_zp = para_->quant_arg.in_quant_args_.zp_;
+
+  for (int i = task_id; i < (int)data_size; i += para_->thread_count_) {
+    output_ptr[i] = output_zp + round(1 / output_scale * input_scale * (input_ptr[i] - input_zp));
+  }
+  return 0;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.h
new file mode 100644
index 00000000..1649945d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/int8/unsqueeze_int8.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INT8_UNSQUEEZE_INT8_H_
+#define NNACL_INT8_UNSQUEEZE_INT8_H_
+
+#include <math.h>
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/unsqueeze_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, const UnSqueezeParameter *para_, size_t data_size,
+                  int task_id);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_INT8_UNSQUEEZE_INT8_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/DeconvMatMulAvx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/DeconvMatMulAvx.c
new file mode 100644
index 00000000..63b930c2
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/DeconvMatMulAvx.c
@@ -0,0 +1,188 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/op_base.h" + +void Deconv4X8AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 tmp = _mm256_set1_ps(*src); + __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM)); + weight += C8NUM; + __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM)); + __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM)); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res4 = _mm256_fmadd_ps(tmp1, w0, res4); + src += C4NUM; + res7 = _mm256_fmadd_ps(tmp2, w0, res7); + res10 = _mm256_fmadd_ps(tmp3, w0, res10); + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); +} + +void Deconv4X16AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + weight += C16NUM; + __m256 tmp = _mm256_set1_ps(*src); + __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM)); + __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM)); + __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM)); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + src += C4NUM; + res4 = _mm256_fmadd_ps(tmp1, w0, res4); + res5 = _mm256_fmadd_ps(tmp1, w1, res5); + res7 = _mm256_fmadd_ps(tmp2, w0, res7); + res8 = _mm256_fmadd_ps(tmp2, w1, res8); + res10 = _mm256_fmadd_ps(tmp3, w0, res10); + res11 = _mm256_fmadd_ps(tmp3, w1, res11); + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); +} + +void Deconv4X24AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res3 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res6 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res9 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + __m256 res12 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + __m256 w2 = _mm256_loadu_ps(weight + C16NUM); + __m256 tmp = _mm256_set1_ps(*src); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + res3 = _mm256_fmadd_ps(tmp, w2, res3); + tmp = _mm256_set1_ps(*(src + C1NUM)); + res4 = _mm256_fmadd_ps(tmp, w0, res4); + res5 = _mm256_fmadd_ps(tmp, w1, res5); + res6 = 
_mm256_fmadd_ps(tmp, w2, res6); + tmp = _mm256_set1_ps(*(src + C2NUM)); + res7 = _mm256_fmadd_ps(tmp, w0, res7); + res8 = _mm256_fmadd_ps(tmp, w1, res8); + res9 = _mm256_fmadd_ps(tmp, w2, res9); + tmp = _mm256_set1_ps(*(src + C3NUM)); + res10 = _mm256_fmadd_ps(tmp, w0, res10); + res11 = _mm256_fmadd_ps(tmp, w1, res11); + res12 = _mm256_fmadd_ps(tmp, w2, res12); + weight += C24NUM; + src += C4NUM; + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); + + _mm256_storeu_ps(dst + C2NUM * stride, res3); + _mm256_storeu_ps(dst + C2NUM * stride + C8NUM, res6); + _mm256_storeu_ps(dst + C2NUM * stride + C16NUM, res9); + _mm256_storeu_ps(dst + C2NUM * stride + C24NUM, res12); +} + +void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, const int plane) { + NNACL_CHECK_ZERO_RETURN(plane); + int col_num = 0; + int col_block = UP_DIV(col / plane, C8NUM); + DeconvAvxKernel kernel[3] = {Deconv4X8AvxKernel, Deconv4X16AvxKernel, Deconv4X24AvxKernel}; + for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) { + col_num = MSMIN(C3NUM, col_block - col_tmp); + for (int p = 0; p < plane; ++p) { + for (int r = 0; r < row; r += C4NUM) { + kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth, + c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth, + row * C8NUM * plane); + } + } + } +} + +#ifdef ENABLE_DEBUG +void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, + int stride) { + __m256 res[C12NUM]; + __m256 w[C3NUM]; + for (int i = 0; i < C12NUM; ++i) { + res[i] = _mm256_setzero_ps(); + } + for (int d = 0; d < depth; ++d) { + for (int c = 0; c < col; ++c) { + w[c] = _mm256_loadu_ps(weight + c * C8NUM); + } + weight += col * C8NUM; + for (int r = 0; r < row; ++r) { // C4NUm + __m256 tmp = _mm256_set1_ps(*src); + for (int c = 0; c < col; ++c) { // 3 * C8NUM + res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]); + } + src += 1; + } + } + // write + for (int i = 0; i < col; ++i) { + for (int j = 0; j < row; ++j) { + _mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]); + } + dst += stride; + } +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c new file mode 100644 index 00000000..91d7d154 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/PostFuncBiasReluC8.c @@ -0,0 +1,352 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/avx/common_utils.h" + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type) { + stride /= sizeof(float); + int loop_c8 = 0; + size_t src_stride = plane_size * C8NUM; + for (; loop_c8 <= (int)(oc8div)-C32NUM; loop_c8 += C32NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + __m256 bias4 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias4 = _mm256_loadu_ps(bias + C24NUM); + bias += C32NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src13 = _mm256_loadu_ps(src + src_stride * C3NUM); + __m256 src14 = _mm256_loadu_ps(src + src_stride * C3NUM + C8NUM); + + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src13 = _mm256_add_ps(src13, bias4); + src14 = _mm256_add_ps(src14, bias4); + + ActBlock8Avx(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + _mm256_storeu_ps(dst_c8 + C16NUM, src9); + _mm256_storeu_ps(dst_c8 + C24NUM, src13); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + _mm256_storeu_ps(dst_c8 + C16NUM, src10); + _mm256_storeu_ps(dst_c8 + C24NUM, src14); + dst_c8 += stride; + + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + __m256 src15 = _mm256_loadu_ps(src + src_stride * C3NUM + C16NUM); + __m256 src16 = _mm256_loadu_ps(src + src_stride * C3NUM + C24NUM); + src += C32NUM; + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + src15 = _mm256_add_ps(src15, bias4); + src16 = _mm256_add_ps(src16, bias4); + + ActBlock8Avx(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + _mm256_storeu_ps(dst_c8 + C16NUM, src11); + _mm256_storeu_ps(dst_c8 + C24NUM, src15); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + _mm256_storeu_ps(dst_c8 + C16NUM, src12); + _mm256_storeu_ps(dst_c8 + C24NUM, src16); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + 
__m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src4 = _mm256_loadu_ps(src + src_stride * C3NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + src4 = _mm256_add_ps(src4, bias4); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + _mm256_storeu_ps(dst_c8 + C16NUM, src3); + _mm256_storeu_ps(dst_c8 + C24NUM, src4); + dst_c8 += stride; + src += C8NUM; + } + src += C3NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C24NUM; loop_c8 += C24NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias += C24NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + + ActBlock12Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, + relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + _mm256_storeu_ps(dst_c8 + C16NUM, src9); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + _mm256_storeu_ps(dst_c8 + C16NUM, src10); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + _mm256_storeu_ps(dst_c8 + C16NUM, src11); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + _mm256_storeu_ps(dst_c8 + C16NUM, src12); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src3, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + _mm256_storeu_ps(dst_c8 + 
C16NUM, src3); + dst_c8 += stride; + src += C8NUM; + } + src += C2NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C16NUM; loop_c8 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + + ActBlock8Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + + ActBlock2Avx(&src1, &src2, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + dst_c8 += stride; + src += C8NUM; + } + src += src_stride; + } + for (; loop_c8 < (int)(oc8div); loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + dst_c8 += stride; + src += C8NUM; + } + } + if (oc8mod == 0) { + return; + } + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 
0; plane_size_tmp -= 1, src += C8NUM, dst_c1 += stride) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + __m128 src_high = _mm256_extractf128_ps(src1, 1); + + switch (oc8mod) { + case 1: + dst_c1[0] = _mm256_cvtss_f32(src1); + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + dst_c1[C2NUM] = MS_F32X8_GETI(src1, C2NUM); + break; + case C4NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + break; + case C5NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_store_ss(dst_c1 + C4NUM, src_high); + break; + case C6NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + break; + case C7NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + dst_c1[C6NUM] = MS_F32X4_GETI(src_high, C2NUM); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c new file mode 100644 index 00000000..7eb313ac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/TiledC8MatMulFp32.c @@ -0,0 +1,274 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifdef ENABLE_AVX
+#ifdef _MSC_VER
+#include <immintrin.h>
+#else
+#include <x86intrin.h>
+#endif
+#include "nnacl_c/fp32/common_func_fp32.h"
+
+void TiledC8MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic8, size_t oc8) {
+  const float *src_tmp = src;
+  for (int i = 0; i < oc8; ++i) {
+    src = src_tmp;
+#ifndef ENABLE_DEBUG
+    asm volatile(
+      "vxorps %%xmm0, %%xmm0, %%xmm0\n"
+      "vmovaps %%ymm0, %%ymm1\n"
+      "vmovaps %%ymm0, %%ymm2\n"
+      "vmovaps %%ymm0, %%ymm3\n"
+      "vmovaps %%ymm0, %%ymm4\n"
+      "vmovaps %%ymm0, %%ymm5\n"
+      "vmovaps %%ymm0, %%ymm6\n"
+      "vmovaps %%ymm0, %%ymm7\n"
+      : /* no input */
+      : /* no input */
+      : "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7");
+#else
+    register __m256 dst1 asm("ymm0") = _mm256_setzero_ps();
+    register __m256 dst2 asm("ymm1") = _mm256_setzero_ps();
+    register __m256 dst3 asm("ymm2") = _mm256_setzero_ps();
+    register __m256 dst4 asm("ymm3") = _mm256_setzero_ps();
+    register __m256 dst5 asm("ymm4") = _mm256_setzero_ps();
+    register __m256 dst6 asm("ymm5") = _mm256_setzero_ps();
+    register __m256 dst7 asm("ymm6") = _mm256_setzero_ps();
+    register __m256 dst8 asm("ymm7") = _mm256_setzero_ps();
+#endif
+    for (size_t ic8_tmp = 0; ic8_tmp < ic8; ++ic8_tmp) {
+#ifndef ENABLE_DEBUG
+      asm volatile(
+        // 1
+        "vmovups (%1), %%ymm8\n"
+
+        "vbroadcastss (%0), %%ymm9\n"
+        "vbroadcastss 32(%0), %%ymm10\n"
+        "vbroadcastss 64(%0), %%ymm11\n"
+        "vbroadcastss 96(%0), %%ymm12\n"
+        "vbroadcastss 128(%0), %%ymm13\n"
+        "vbroadcastss 160(%0), %%ymm14\n"
+
+        "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n"
+        "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n"
+        "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n"
+        "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n"
+        "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n"
+        "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n"
+
+        "vbroadcastss 192(%0), %%ymm9\n"
+        "vbroadcastss 224(%0), %%ymm10\n"
+        "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n"
+        "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n"
+
+        // 2
+        "vmovups 32(%1), %%ymm15\n"
+
+        "vbroadcastss 4(%0), %%ymm11\n"
+        "vbroadcastss 36(%0), %%ymm12\n"
+        "vbroadcastss 68(%0), %%ymm13\n"
+        "vbroadcastss 100(%0), %%ymm14\n"
+        "vbroadcastss 132(%0), %%ymm9\n"
+        "vbroadcastss 164(%0), %%ymm10\n"
+
+        "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n"
+        "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n"
+        "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n"
+        "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n"
+        "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n"
+        "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n"
+
+        "vbroadcastss 196(%0), %%ymm11\n"
+        "vbroadcastss 228(%0), %%ymm12\n"
+        "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n"
+        "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
+
+        // 3
+        "vmovups 64(%1), %%ymm8\n"
+
+        "vbroadcastss 8(%0), %%ymm13\n"
+        "vbroadcastss 40(%0), %%ymm14\n"
+        "vbroadcastss 72(%0), %%ymm9\n"
+        "vbroadcastss 104(%0), %%ymm10\n"
+        "vbroadcastss 136(%0), %%ymm11\n"
+        "vbroadcastss 168(%0), %%ymm12\n"
+
+        "vfmadd231ps %%ymm13, %%ymm8, %%ymm0\n"
+        "vfmadd231ps %%ymm14, %%ymm8, %%ymm1\n"
+        "vfmadd231ps %%ymm9, %%ymm8, %%ymm2\n"
+        "vfmadd231ps %%ymm10, %%ymm8, %%ymm3\n"
+        "vfmadd231ps %%ymm11, %%ymm8, %%ymm4\n"
+        "vfmadd231ps %%ymm12, %%ymm8, %%ymm5\n"
+
+        "vbroadcastss 200(%0), %%ymm13\n"
+        "vbroadcastss 232(%0), %%ymm14\n"
+        "vfmadd231ps %%ymm13, %%ymm8, %%ymm6\n"
+        "vfmadd231ps %%ymm14, %%ymm8, %%ymm7\n"
+
+        // 4
+        "vmovups 96(%1), %%ymm15\n"
+
+        "vbroadcastss 12(%0), %%ymm9\n"
+        "vbroadcastss 44(%0), %%ymm10\n"
+        "vbroadcastss 76(%0), %%ymm11\n"
+        "vbroadcastss 108(%0), %%ymm12\n"
+        "vbroadcastss 140(%0), %%ymm13\n"
+        "vbroadcastss 172(%0), %%ymm14\n"
+
+        "vfmadd231ps %%ymm9, %%ymm15, 
%%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n" + + "vbroadcastss 204(%0), %%ymm9\n" + "vbroadcastss 236(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm7\n" + + // 5 + "vmovups 128(%1), %%ymm8\n" + + "vbroadcastss 16(%0), %%ymm11\n" + "vbroadcastss 48(%0), %%ymm12\n" + "vbroadcastss 80(%0), %%ymm13\n" + "vbroadcastss 112(%0), %%ymm14\n" + "vbroadcastss 144(%0), %%ymm9\n" + "vbroadcastss 176(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm5\n" + + "vbroadcastss 208(%0), %%ymm11\n" + "vbroadcastss 240(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm7\n" + + // 6 + "vmovups 160(%1), %%ymm15\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vbroadcastss 52(%0), %%ymm14\n" + "vbroadcastss 84(%0), %%ymm9\n" + "vbroadcastss 116(%0), %%ymm10\n" + "vbroadcastss 148(%0), %%ymm11\n" + "vbroadcastss 180(%0), %%ymm12\n" + + "vfmadd231ps %%ymm13, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + + "vbroadcastss 212(%0), %%ymm13\n" + "vbroadcastss 244(%0), %%ymm14\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm7\n" + + // 7 + "vmovups 192(%1), %%ymm8\n" + + "vbroadcastss 24(%0), %%ymm9\n" + "vbroadcastss 56(%0), %%ymm10\n" + "vbroadcastss 88(%0), %%ymm11\n" + "vbroadcastss 120(%0), %%ymm12\n" + "vbroadcastss 152(%0), %%ymm13\n" + "vbroadcastss 184(%0), %%ymm14\n" + + "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n" + + "vbroadcastss 216(%0), %%ymm9\n" + "vbroadcastss 248(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n" + + // 8 + "vmovups 224(%1), %%ymm15\n" + + "vbroadcastss 28(%0), %%ymm11\n" + "vbroadcastss 60(%0), %%ymm12\n" + "vbroadcastss 92(%0), %%ymm13\n" + "vbroadcastss 124(%0), %%ymm14\n" + "vbroadcastss 156(%0), %%ymm9\n" + "vbroadcastss 188(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n" + + "vbroadcastss 220(%0), %%ymm11\n" + "vbroadcastss 252(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + : + : "r"(src), "r"(weight) + : "memory"); +#else + for (int j = 0; j < C8NUM; ++j) { + __m256 weight_data = _mm256_loadu_ps(weight + j * C8NUM); + dst1 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j)), dst1); + dst2 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C8NUM)), dst2); + dst3 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C16NUM)), dst3); + dst4 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C24NUM)), dst4); + dst5 = 
_mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C32NUM)), dst5); + dst6 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C40NUM)), dst6); + dst7 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C48NUM)), dst7); + dst8 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C56NUM)), dst8); + } +#endif + src += C64NUM; + weight += C64NUM; + } +#ifndef ENABLE_DEBUG + asm volatile( + "vmovups %%ymm0, (%[dst])\n\t" + "vmovups %%ymm1, 32(%[dst])\n\t" + "vmovups %%ymm2, 64(%[dst])\n\t" + "vmovups %%ymm3, 96(%[dst])\n\t" + "vmovups %%ymm4, 128(%[dst])\n\t" + "vmovups %%ymm5, 160(%[dst])\n\t" + "vmovups %%ymm6, 192(%[dst])\n\t" + "vmovups %%ymm7, 224(%[dst])\n\t" + : + : [dst] "r"(dst) + : "memory", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7"); +#else + _mm256_storeu_ps(dst, dst1); + _mm256_storeu_ps(dst + C8NUM, dst2); + _mm256_storeu_ps(dst + C16NUM, dst3); + _mm256_storeu_ps(dst + C24NUM, dst4); + _mm256_storeu_ps(dst + C32NUM, dst5); + _mm256_storeu_ps(dst + C40NUM, dst6); + _mm256_storeu_ps(dst + C48NUM, dst7); + _mm256_storeu_ps(dst + C56NUM, dst8); +#endif + dst += cal_num; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c new file mode 100644 index 00000000..c67513fc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradPostFuncBiasReluC8.c @@ -0,0 +1,357 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/avx/common_utils.h" + +void WinogradPostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t plane_stride, size_t relu_type) { + size_t stride = oc8div + oc8mod; + plane_stride /= sizeof(float); + int loop_c8 = 0; + size_t src_stride = plane_size * C8NUM + plane_stride; + for (; loop_c8 <= (int)(oc8div)-C32NUM; loop_c8 += C32NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + __m256 bias4 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias4 = _mm256_loadu_ps(bias + C24NUM); + bias += C32NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src13 = _mm256_loadu_ps(src + src_stride * C3NUM); + __m256 src14 = _mm256_loadu_ps(src + src_stride * C3NUM + C8NUM); + + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src13 = _mm256_add_ps(src13, bias4); + src14 = _mm256_add_ps(src14, bias4); + + ActBlock8Avx(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + _mm256_stream_ps(dst_c8 + C16NUM, src9); + _mm256_stream_ps(dst_c8 + C24NUM, src13); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + _mm256_stream_ps(dst_c8 + C16NUM, src10); + _mm256_stream_ps(dst_c8 + C24NUM, src14); + dst_c8 += stride; + + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + __m256 src15 = _mm256_loadu_ps(src + src_stride * C3NUM + C16NUM); + __m256 src16 = _mm256_loadu_ps(src + src_stride * C3NUM + C24NUM); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + src15 = _mm256_add_ps(src15, bias4); + src16 = _mm256_add_ps(src16, bias4); + + ActBlock8Avx(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + _mm256_stream_ps(dst_c8 + C16NUM, src11); + _mm256_stream_ps(dst_c8 + C24NUM, src15); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + _mm256_stream_ps(dst_c8 + C16NUM, src12); + _mm256_stream_ps(dst_c8 + C24NUM, src16); + dst_c8 += stride; + src += C32NUM; + } + for (; plane_size_tmp > 
0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src4 = _mm256_loadu_ps(src + src_stride * C3NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + src4 = _mm256_add_ps(src4, bias4); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + _mm256_stream_ps(dst_c8 + C16NUM, src3); + _mm256_stream_ps(dst_c8 + C24NUM, src4); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += C3NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C24NUM; loop_c8 += C24NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias += C24NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + + ActBlock12Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, + relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + _mm256_stream_ps(dst_c8 + C16NUM, src9); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + _mm256_stream_ps(dst_c8 + C16NUM, src10); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + _mm256_stream_ps(dst_c8 + C16NUM, src11); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + _mm256_stream_ps(dst_c8 + C16NUM, src12); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src3, relu_type == 1, relu_type == C3NUM); + + 
_mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + _mm256_stream_ps(dst_c8 + C16NUM, src3); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += C2NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C16NUM; loop_c8 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + + ActBlock8Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + + ActBlock2Avx(&src1, &src2, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += src_stride; + } + for (; loop_c8 < (int)(oc8div); loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + } + if (oc8mod == 0) { + return; + } + __m256 bias1 = 
_mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += C8NUM, dst_c1 += stride) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + __m128 src_high = _mm256_extractf128_ps(src1, 1); + + switch (oc8mod) { + case 1: + dst_c1[0] = _mm256_cvtss_f32(src1); + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + dst_c1[C2NUM] = MS_F32X8_GETI(src1, C2NUM); + break; + case C4NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + break; + case C5NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_store_ss(dst_c1 + C4NUM, src_high); + break; + case C6NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + break; + case C7NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + dst_c1[C6NUM] = MS_F32X4_GETI(src_high, C2NUM); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c new file mode 100644 index 00000000..634d19f4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/WinogradTransAvx.c @@ -0,0 +1,355 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifdef ENABLE_AVX +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c8 = length * C8NUM; + size_t S_step = length * w * C8NUM; + for (int h1 = 0; h1 < h; ++h1) { + const float *SW = S; + memset(M, 0, len_c8 * w * sizeof(float)); + for (int w_tmp = w; w_tmp > 0; --w_tmp) { + const float *SK = SW; + const float *BK = B; + int k_tmp = k; + for (; k_tmp >= C8NUM; k_tmp -= C8NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + __m256 k5 = _mm256_set1_ps(*(BK + C4NUM * h)); + __m256 k6 = _mm256_set1_ps(*(BK + C5NUM * h)); + __m256 k7 = _mm256_set1_ps(*(BK + C6NUM * h)); + __m256 k8 = _mm256_set1_ps(*(BK + C7NUM * h)); + BK += C8NUM * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + M2 = _mm256_fmadd_ps(s4, k4, M2); + __m256 s5 = _mm256_loadu_ps(SK + C4NUM * S_step); + M1 = _mm256_fmadd_ps(s5, k5, M1); + __m256 s6 = _mm256_loadu_ps(SK + C5NUM * S_step); + M2 = _mm256_fmadd_ps(s6, k6, M2); + __m256 s7 = _mm256_loadu_ps(SK + C6NUM * S_step); + M1 = _mm256_fmadd_ps(s7, k7, M1); + __m256 s8 = _mm256_loadu_ps(SK + C7NUM * S_step); + M2 = _mm256_fmadd_ps(s8, k8, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C8NUM * S_step - len_c8; + } + for (; k_tmp >= C4NUM; k_tmp -= C4NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + BK += C4NUM * h; + int len_tmp = length; + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, SK += C16NUM, M += C16NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + __m256 s22 = _mm256_loadu_ps(SK + S_step + C8NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + __m256 s33 = _mm256_loadu_ps(SK + C2NUM * S_step + C8NUM); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + __m256 s44 = _mm256_loadu_ps(SK + C3NUM * S_step + C8NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M4 = _mm256_fmadd_ps(s44, k4, M4); + M1 = _mm256_add_ps(M1, M2); + M4 = _mm256_add_ps(M3, M4); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M4); + } + for (; len_tmp > 0; len_tmp--, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, 
k3, M1); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C4NUM * S_step - len_c8; + } + for (; k_tmp >= C3NUM; k_tmp -= C3NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + BK += C3NUM * h; + int len_tmp = length; + for (; len_tmp >= C3NUM; len_tmp -= C3NUM, SK += C24NUM, M += C24NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 M5 = _mm256_loadu_ps(M + C16NUM); + __m256 M6 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s22 = _mm256_loadu_ps(SK + S_step + C8NUM); + __m256 s111 = _mm256_loadu_ps(SK + C16NUM); + __m256 s222 = _mm256_loadu_ps(SK + S_step + C16NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + M5 = _mm256_fmadd_ps(s111, k1, M5); + M6 = _mm256_fmadd_ps(s222, k2, M6); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + __m256 s33 = _mm256_loadu_ps(SK + C2NUM * S_step + C8NUM); + __m256 s333 = _mm256_loadu_ps(SK + C2NUM * S_step + C16NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M5 = _mm256_fmadd_ps(s333, k3, M5); + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + M5 = _mm256_add_ps(M5, M6); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + _mm256_storeu_ps(M + C16NUM, M5); + } + for (; len_tmp > 0; len_tmp--, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C3NUM * S_step - len_c8; + } + for (; k_tmp > 0; k_tmp -= 1) { + __m256 k1 = _mm256_set1_ps(*BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 s0 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s0, k1, M1); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += S_step - len_c8; + } + SW += len_c8; + M += len_c8; + } + B += 1; + } +} + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c8 = length * C8NUM, k_step = len_c8 * k; + for (int h1 = 0; h1 < h; ++h1, S += k_step) { + const float *BW = B; + memset(M, 0, len_c8 * w * sizeof(float)); + for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c8) { + const float *SK = S, *BK = BW; + int k_tmp = k; + for (; k_tmp >= C8NUM; k_tmp -= C8NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + __m256 k5 = _mm256_set1_ps(*(BK + C4NUM * h)); + __m256 k6 = _mm256_set1_ps(*(BK + C5NUM * h)); + __m256 k7 = _mm256_set1_ps(*(BK + C6NUM * h)); + __m256 k8 = _mm256_set1_ps(*(BK + C7NUM * h)); + BK += C8NUM * h; + const float *S2 = SK + len_c8, *S3 = S2 + len_c8; + const float *S4 = S3 + len_c8, *S5 = S4 + len_c8; + const float *S6 = S5 + len_c8, *S7 = 
S6 + len_c8, *S8 = S7 + len_c8; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, + S4 += C8NUM, S5 += C8NUM, S6 += C8NUM, S7 += C8NUM, S8 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(S4); + M2 = _mm256_fmadd_ps(s4, k4, M2); + __m256 s5 = _mm256_loadu_ps(S5); + M1 = _mm256_fmadd_ps(s5, k5, M1); + __m256 s6 = _mm256_loadu_ps(S6); + M2 = _mm256_fmadd_ps(s6, k6, M2); + __m256 s7 = _mm256_loadu_ps(S7); + M1 = _mm256_fmadd_ps(s7, k7, M1); + __m256 s8 = _mm256_loadu_ps(S8); + M2 = _mm256_fmadd_ps(s8, k8, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S7; + } + for (; k_tmp >= C4NUM; k_tmp -= C4NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + BK += C4NUM * h; + const float *S2 = SK + len_c8; + const float *S3 = S2 + len_c8; + const float *S4 = S3 + len_c8; + int len_tmp = length; + for (; len_tmp >= C2NUM; + len_tmp -= C2NUM, M += C16NUM, SK += C16NUM, S2 += C16NUM, S3 += C16NUM, S4 += C16NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(S2); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s22 = _mm256_loadu_ps(S2 + C8NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + + __m256 s3 = _mm256_loadu_ps(S3); + __m256 s4 = _mm256_loadu_ps(S4); + __m256 s33 = _mm256_loadu_ps(S3 + C8NUM); + __m256 s44 = _mm256_loadu_ps(S4 + C8NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M4 = _mm256_fmadd_ps(s44, k4, M4); + + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + } + for (; len_tmp > 0; len_tmp--, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, S4 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(S4); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S4; + } + for (; k_tmp >= C3NUM; k_tmp -= C3NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + BK += C3NUM * h; + const float *S2 = SK + len_c8; + const float *S3 = S2 + len_c8; + int len_tmp = length; + for (; len_tmp >= C3NUM; len_tmp -= C3NUM, M += C24NUM, SK += C24NUM, S2 += C24NUM, S3 += C24NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 M5 = _mm256_loadu_ps(M + C16NUM); + __m256 M6 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(S2); + __m256 s11 = _mm256_loadu_ps(SK + 
C8NUM); + __m256 s22 = _mm256_loadu_ps(S2 + C8NUM); + __m256 s111 = _mm256_loadu_ps(SK + C16NUM); + __m256 s222 = _mm256_loadu_ps(S2 + C16NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + M5 = _mm256_fmadd_ps(s111, k1, M5); + M6 = _mm256_fmadd_ps(s222, k2, M6); + __m256 s3 = _mm256_loadu_ps(S3); + __m256 s33 = _mm256_loadu_ps(S3 + C8NUM); + __m256 s333 = _mm256_loadu_ps(S3 + C16NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M5 = _mm256_fmadd_ps(s333, k3, M5); + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + M5 = _mm256_add_ps(M6, M5); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + _mm256_storeu_ps(M + C16NUM, M5); + } + for (; len_tmp > 0; len_tmp--, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S3; + } + for (; k_tmp >= 1; k_tmp -= 1, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + _mm256_storeu_ps(M, M1); + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.c new file mode 100644 index 00000000..8335748a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.c @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl_c/intrinsics/avx/common_utils.h" +#include <stdint.h> + +__m128i _mm_adds_epi32(__m128i a, __m128i b) { + __m128i int_min = _mm_set1_epi32(0x80000000); + __m128i int_max = _mm_set1_epi32(0x7FFFFFFF); + + const __m128i res = _mm_add_epi32(a, b); + const __m128i sign_and = _mm_and_si128(a, b); + const __m128i sign_or = _mm_or_si128(a, b); + + const __m128i min_sat_mask = _mm_andnot_si128(res, sign_and); + const __m128i max_sat_mask = _mm_andnot_si128(sign_or, res); + const __m128 res_temp = + _mm_blendv_ps(_mm_castsi128_ps(res), _mm_castsi128_ps(int_min), _mm_castsi128_ps(min_sat_mask)); + return _mm_castps_si128(_mm_blendv_ps(res_temp, _mm_castsi128_ps(int_max), _mm_castsi128_ps(max_sat_mask))); +} + +__m128i _mm_rshr_epi32(__m128i a, int shift) { + const __m128i vmask = _mm_cmpgt_epi32(_mm_setzero_si128(), a); + const __m128i vabs_a = _mm_sub_epi32(_mm_xor_si128(a, vmask), vmask); + const __m128i tmp_res = _mm_srli_epi32(vabs_a, shift); + return _mm_xor_si128(tmp_res, vmask); +} + +__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b) { + const __m128i tmp_a_lo = _mm_unpacklo_epi32(a, _mm_setzero_si128()); + const __m128i tmp_a_hi = _mm_unpackhi_epi32(a, _mm_setzero_si128()); + const __m256i tmp_a_256 = _mm256_set_m128i(tmp_a_hi, tmp_a_lo); + const __m128i tmp_b_lo = _mm_unpacklo_epi32(b, _mm_setzero_si128()); + const __m128i tmp_b_hi = _mm_unpackhi_epi32(b, _mm_setzero_si128()); + const __m256i tmp_b_256 = _mm256_set_m128i(tmp_b_hi, tmp_b_lo); + __m256i tmp_out = _mm256_mul_epi32(tmp_a_256, tmp_b_256); + tmp_out = _mm256_add_epi64(tmp_out, _mm256_set1_epi64x(1ll << 30)); + const __m256i vmask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), tmp_out); + const __m256i vabs_tmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask); + tmp_out = _mm256_srli_epi64(vabs_tmp_out, 31); + const __m256i vtmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask); + const int32_t max_32bit = (1ll << 31) - 1; + const int32_t min_32bit = -(1ll << 31); + int64_t *tmp_out_ptr = (int64_t *)(&vtmp_out); + int32_t r1 = tmp_out_ptr[0] > max_32bit ? max_32bit : tmp_out_ptr[0]; + r1 = r1 < min_32bit ? min_32bit : r1; + int32_t r2 = tmp_out_ptr[1] > max_32bit ? max_32bit : tmp_out_ptr[1]; + r2 = r2 < min_32bit ? min_32bit : r2; + int32_t r3 = tmp_out_ptr[2] > max_32bit ? max_32bit : tmp_out_ptr[2]; + r3 = r3 < min_32bit ? min_32bit : r3; + int32_t r4 = tmp_out_ptr[3] > max_32bit ? max_32bit : tmp_out_ptr[3]; + r4 = r4 < min_32bit ? min_32bit : r4; + return _mm_set_epi32(r4, r3, r2, r1); +}
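+/* Annotation (illustrative, not part of the original patch): per 32-bit lane, _mm_qrdmulh_epi32
+ * mirrors ARM's VQRDMULH -- a signed saturating, rounding "doubling multiply, return high half".
+ * Scalar equivalent, assuming a hypothetical QRDMulhScalar helper:
+ *
+ *   static inline int32_t QRDMulhScalar(int32_t a, int32_t b) {
+ *     int64_t prod = (int64_t)a * (int64_t)b;                 // full 64-bit product
+ *     int64_t rounded = (prod + (1ll << 30)) / (1ll << 31);   // add rounding bias, keep high half
+ *     if (rounded > INT32_MAX) return INT32_MAX;              // saturate (e.g. a == b == INT32_MIN)
+ *     if (rounded < INT32_MIN) return INT32_MIN;
+ *     return (int32_t)rounded;
+ *   }
+ *
+ * The truncating division matches the sign-magnitude logical shift used in the vector code above. */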
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.h new file mode 100644 index 00000000..0b80a83a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/avx/common_utils.h @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ +#define MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ + +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif + +#ifdef __cplusplus +extern "C" { +#endif +#ifdef __GNUC__ +#if __GNUC__ < 8 +#define _mm256_set_m128i(xmm1, xmm2) \ + _mm256_permute2f128_si256(_mm256_castsi128_si256(xmm1), _mm256_castsi128_si256(xmm2), 2) +#define _mm256_set_m128f(xmm1, xmm2) \ + _mm256_permute2f128_ps(_mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2) +#endif +#endif + +#define AVX_ACT_RELU 1 +#define AVX_ACT_RELU6 3 + +// Signed saturating Add +__m128i _mm_adds_epi32(__m128i a, __m128i b); + +// Signed rounding shift right +__m128i _mm_rshr_epi32(__m128i a, int shift); + +// Signed saturating Rounding Doubling Multiply return High half +__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b); + +static inline void ActBlock1Avx(__m256 *v1, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + } +} + +static inline void ActBlock2Avx(__m256 *v1, __m256 *v2, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + } +} + +static inline void ActBlock4Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + *v3 = _mm256_max_ps(zero_ma, *v3); + *v4 = _mm256_max_ps(zero_ma, *v4); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + *v3 = _mm256_min_ps(relu6_ma, *v3); + *v4 = _mm256_min_ps(relu6_ma, *v4); + } +} + +static inline void ActBlock8Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, __m256 *v5, __m256 *v6, __m256 *v7, + __m256 *v8, size_t relu_type) { + __m256 relu6 = _mm256_set1_ps(6.0); + __m256 zero = _mm256_setzero_ps(); + switch (relu_type) { + case AVX_ACT_RELU6: + *v1 = _mm256_min_ps(*v1, relu6); + *v2 = _mm256_min_ps(*v2, relu6); + *v3 = _mm256_min_ps(*v3, relu6); + *v4 = _mm256_min_ps(*v4, relu6); + *v5 = _mm256_min_ps(*v5, relu6); + *v6 = _mm256_min_ps(*v6, relu6); + *v7 = _mm256_min_ps(*v7, relu6); + *v8 = _mm256_min_ps(*v8, relu6); + case AVX_ACT_RELU: + *v1 = _mm256_max_ps(*v1, zero); + *v2 = _mm256_max_ps(*v2, zero); + *v3 = _mm256_max_ps(*v3, zero); + *v4 = _mm256_max_ps(*v4, zero); + *v5 = _mm256_max_ps(*v5, zero); + *v6 = _mm256_max_ps(*v6, zero); + *v7 = _mm256_max_ps(*v7, zero); + *v8 = _mm256_max_ps(*v8, zero); + break; + default: + break; + } +}
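+/* Annotation (not in the original sources): the missing break after the AVX_ACT_RELU6 case above is
+ * deliberate -- relu6 clamps to 6.0f and then falls through into the shared max(0, x) of
+ * AVX_ACT_RELU, while any other relu_type leaves the vectors untouched. Scalar sketch of the same
+ * dispatch, using a hypothetical ActScalar helper:
+ *
+ *   static inline float ActScalar(float x, size_t relu_type) {
+ *     if (relu_type == AVX_ACT_RELU6) x = x < 6.0f ? x : 6.0f;  // clamp top, then fall through
+ *     if (relu_type == AVX_ACT_RELU6 || relu_type == AVX_ACT_RELU) x = x > 0.0f ? x : 0.0f;
+ *     return x;
+ *   }
+ */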
+ +static inline void ActBlock12Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, __m256 *v5, __m256 *v6, __m256 *v7, + __m256 *v8, __m256 *v9, __m256 *v10, __m256 *v11, __m256 *v12, size_t relu, + size_t relu6) { + if (relu || relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + *v3 = _mm256_max_ps(zero_ma, *v3); + *v4 = _mm256_max_ps(zero_ma, *v4); + *v5 = _mm256_max_ps(zero_ma, *v5); + *v6 = _mm256_max_ps(zero_ma, *v6); + *v7 = _mm256_max_ps(zero_ma, *v7); + *v8 = _mm256_max_ps(zero_ma, *v8); + *v9 = _mm256_max_ps(zero_ma, *v9); + *v10 = _mm256_max_ps(zero_ma, *v10); + *v11 = _mm256_max_ps(zero_ma, *v11); + *v12 = _mm256_max_ps(zero_ma, *v12); + } + if (relu6) { + __m256 relu6_ma = _mm256_set1_ps(6.0f); + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + *v3 = _mm256_min_ps(relu6_ma, *v3); + *v4 = _mm256_min_ps(relu6_ma, *v4); + *v5 = _mm256_min_ps(relu6_ma, *v5); + *v6 = _mm256_min_ps(relu6_ma, *v6); + *v7 = _mm256_min_ps(relu6_ma, *v7); + *v8 = _mm256_min_ps(relu6_ma, *v8); + *v9 = _mm256_min_ps(relu6_ma, *v9); + *v10 = _mm256_min_ps(relu6_ma, *v10); + *v11 = _mm256_min_ps(relu6_ma, *v11); + *v12 = _mm256_min_ps(relu6_ma, *v12); + } +} + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx512_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx512_instructions.h new file mode 100644 index 00000000..5918725b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx512_instructions.h @@ -0,0 +1,446 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include <float.h> +#include <math.h> + +#ifdef _MSC_VER +#include <immintrin.h> +#define MS_F32X16_GETI(src, i) src.m512_f32[i] +#define MS512_F32_GETI(src, i) src.m512_f32[i] +#else +#include <x86intrin.h> +#define MS_F32X16_GETI(src, i) src[i] +#define MS512_F32_GETI(src, i) src[i] +#endif + +#pragma GCC push_options +#pragma GCC target("avx512f") + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X16 __m512 +#define MS_FLOAT512_F32 __m512 +#define MS_INT32X16 __m512i +#define MS_INT512_EPI32 __m512i +#define MS_MASK512_TYPE __mmask16 +#define MS_LD512_F32 _mm512_loadu_ps +#define MS_LD512_EPI32(src) _mm512_loadu_si512((__m512i const *)(src)) +#define MS_LD512_HALF_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) +#define MS_ADD512_F32 _mm512_add_ps +#define MS_ADD512_EPI32 _mm512_add_epi32 +#define MS_MOV512_F32 _mm512_set1_ps +#define MS_MOV512_EPI32 _mm512_set1_epi32 +#define MS_MOV512_VAL0_F32 _mm512_setzero_ps() +#define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1) +#define MS_ST512_F32 _mm512_storeu_ps +#define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2) +#define MS_ST512_HALF_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) +#define MS_SUB512_F32 _mm512_sub_ps +#define MS_SUB512_EPI32 _mm512_sub_epi32 +#define MS_MAX512_F32 _mm512_max_ps +#define MS_MAX512_EPI32 _mm512_max_epi32 +#define MS_MIN512_F32 _mm512_min_ps +#define MS_MIN512_EPI32 _mm512_min_epi32 +#define MS_SQRT512_F32 _mm512_sqrt_ps +#define MS_RSQRT512_F32 _mm512_rsqrt14_ps +#define MS_SIN512_F32 _mm512_sin_ps +#define MS_ERF512_F32 _mm512_erf_ps +#define MS_ABS512_F32 _mm512_abs_ps +#define MS_ABS512_EPI32 _mm512_abs_epi32 + +#define MS_ROUND512_F32(src) \
+ _mm512_add_round_ps(src, _mm512_set1_ps(0.0f), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR512_F32 _mm512_floor_ps +#define MS_CEIL512_F32 _mm512_ceil_ps +#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2) +#define MS_MUL512_EPI32(src1, src2) _mm512_mullo_epi32(src1, src2) +#define MS_FMADD512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3) +#define MS_FMSUB512_F32(src1, src2, src3) _mm512_fmsub_ps(src1, src2, src3) +#define MS_FSMUL512_F32(src1, src2, src3) _mm512_fnmadd_ps(src3, src2, src1) // src1 - src2 * src3 +#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2) +#define MS_MUL512_N_F32(src1, src2) _mm512_mul_ps(src1, _mm512_set1_ps(src2)) +#define MS_MUL512_N_EPI32(src1, src2) _mm512_mullo_epi32(src1, _mm512_set1_epi32(src2)) +#define MS_DIV512_N_F32(src1, src2) _mm512_div_ps(src1, _mm512_set1_ps(src2)) +#define MS_SLLI512_EPI32(src1, src2) _mm512_slli_epi32(src1, src2) +#define MS_CVT512PS_EPI32(src) _mm512_cvttps_epi32(src) // truncate float to int +#define MS_CVT512EPI32_PS(src) _mm512_cvtepi32_ps(src) +#define MS_CMP512_F32(src1, src2, src3) _mm512_cmp_ps_mask(src1, src2, src3) +#define MS_CMPGT512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 30) +#define MS_CMPLE512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 18) +#define MS_CMPLT512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 17) +#define MS_CMPGT512_EPI32(src1, src2) _mm512_cmpgt_epi32(src1, src2) +#define MS_BLEND512_F32(src1, src2, mask) _mm512_mask_blend_ps(mask, src1, src2) +#define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2) +#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src) +#define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src) +#define MS_GET_MAX512_F32(src) _mm512_reduce_max_ps(src) +#define MS_GET_MIN512_F32(src) _mm512_reduce_min_ps(src) +#define MS_GET_SUM512_F32(src) _mm512_reduce_add_ps(src) +#define MS_AND512_MASK(src1, src2) _mm512_kand(src1, src2) + +#define MS512_SRLI_EPI32(src1, src2) _mm512_srli_epi32(src1, src2) +#define MS512_AND_EPI32(src1, src2) _mm512_and_si512(src1, src2) +#define MS512_CASTPS_EPI32(src) _mm512_castps_si512(src) +#define MS_OR512_EPI32(src1, src2) _mm512_or_epi32(src1, src2) +#define MS_AND512_EPI32(src1, src2) _mm512_and_epi32(src1, src2) +#define MS_AND512_F32(src1, src2) \ + _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(src1), _mm512_castps_si512(src2))) + +static inline MS_FLOAT512_F32 SIMD_SIGN512_F32(MS_FLOAT512_F32 src) { + MS_FLOAT512_F32 abs_src = MS_ABS512_F32(src); + MS_FLOAT512_F32 sign = MS_DIV512_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS512_F32(src, abs_src) MS_DIV512_F32(abs_src, src) + +static inline MS_FLOAT512_F32 MS_OR512_F32(MS_FLOAT512_F32 src1, MS_FLOAT512_F32 src2) { + MS_FLOAT512_F32 result = MS_CAST512_F32_S32(MS_OR512_EPI32(MS512_CASTPS_EPI32(src1), MS512_CASTPS_EPI32(src2))); + return result; +} + +static inline MS_FLOAT512_F32 MS512_ANDNOT_F32(MS_FLOAT512_F32 src1, MS_FLOAT512_F32 src2) { + MS_FLOAT512_F32 result = MS_CAST512_F32_S32(MS_AND512_EPI32(~MS512_CASTPS_EPI32(src1), MS512_CASTPS_EPI32(src2))); + return result; +} + +static inline MS_FLOAT512_F32 MS_AND512_MASK_F32(MS_MASK512_TYPE mask, MS_FLOAT512_F32 value) { + /* mask = T ?
value ; 0 */ + MS_FLOAT512_F32 zeros = _mm512_set1_ps(0.0f); + return _mm512_mask_blend_ps(mask, zeros, value); +} + +static inline MS_FLOAT32X16 MS_POW512_F32(MS_FLOAT32X16 src1, MS_FLOAT32X16 src2) { + MS_FLOAT32X16 dst; + MS512_F32_GETI(dst, 0) = powf(MS512_F32_GETI(src1, 0), MS512_F32_GETI(src2, 0)); + MS512_F32_GETI(dst, 1) = powf(MS512_F32_GETI(src1, 1), MS512_F32_GETI(src2, 1)); + MS512_F32_GETI(dst, 2) = powf(MS512_F32_GETI(src1, 2), MS512_F32_GETI(src2, 2)); + MS512_F32_GETI(dst, 3) = powf(MS512_F32_GETI(src1, 3), MS512_F32_GETI(src2, 3)); + MS512_F32_GETI(dst, 4) = powf(MS512_F32_GETI(src1, 4), MS512_F32_GETI(src2, 4)); + MS512_F32_GETI(dst, 5) = powf(MS512_F32_GETI(src1, 5), MS512_F32_GETI(src2, 5)); + MS512_F32_GETI(dst, 6) = powf(MS512_F32_GETI(src1, 6), MS512_F32_GETI(src2, 6)); + MS512_F32_GETI(dst, 7) = powf(MS512_F32_GETI(src1, 7), MS512_F32_GETI(src2, 7)); + MS512_F32_GETI(dst, 8) = powf(MS512_F32_GETI(src1, 8), MS512_F32_GETI(src2, 8)); + MS512_F32_GETI(dst, 9) = powf(MS512_F32_GETI(src1, 9), MS512_F32_GETI(src2, 9)); + MS512_F32_GETI(dst, 10) = powf(MS512_F32_GETI(src1, 10), MS512_F32_GETI(src2, 10)); + MS512_F32_GETI(dst, 11) = powf(MS512_F32_GETI(src1, 11), MS512_F32_GETI(src2, 11)); + MS512_F32_GETI(dst, 12) = powf(MS512_F32_GETI(src1, 12), MS512_F32_GETI(src2, 12)); + MS512_F32_GETI(dst, 13) = powf(MS512_F32_GETI(src1, 13), MS512_F32_GETI(src2, 13)); + MS512_F32_GETI(dst, 14) = powf(MS512_F32_GETI(src1, 14), MS512_F32_GETI(src2, 14)); + MS512_F32_GETI(dst, 15) = powf(MS512_F32_GETI(src1, 15), MS512_F32_GETI(src2, 15)); + return dst; +} + +static inline MS_FLOAT32X16 MS_COS512_F32(MS_FLOAT32X16 src) { + static const MS_FLOAT32X16 pi = {PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI}; + static const MS_FLOAT32X16 pi2_neg = {-2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, + -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI}; + static const MS_FLOAT32X16 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT512_F32 src_abs = MS_ABS512_F32(src); + MS_FLOAT512_F32 src_cycle = + MS_ADD512_F32(MS_MUL512_F32(MS_FLOOR512_F32(MS_MUL512_F32(MS_ADD512_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + static const MS_FLOAT512_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT512_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT512_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT512_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT512_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 
-1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + MS_FLOAT32X16 square = MS_MUL512_F32(src_cycle, src_cycle); + + MS_FLOAT32X16 tmp = + MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_MUL512_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X16 tmp1 = MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(tmp, neg), square), data2); + MS_FLOAT512_F32 res = MS_ADD512_F32( + MS_MUL512_F32( + MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + +static inline MS_FLOAT32X16 MS512_LOG_F32(MS_FLOAT32X16 src) { + const MS_INT512_EPI32 gFloatExpMask = MS_MOV512_EPI32(0xffULL << 23); + const MS_INT512_EPI32 gFloatExp0 = MS_MOV512_EPI32(127ULL << 23); + const MS_INT512_EPI32 gExpNormalizer = MS_MOV512_EPI32(127); + static const MS_FLOAT512_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT512_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, + 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT512_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, + 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT512_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, + 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT512_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT512_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT512_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X16 ln2 = {LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2}; + + const MS_INT512_EPI32 exps32 = MS512_SRLI_EPI32(MS512_AND_EPI32(gFloatExpMask, MS512_CASTPS_EPI32(src)), 23); + const MS_INT512_EPI32 normExps = MS_SUB512_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X16 expsPD = MS_CVT512EPI32_PS(normExps); + const MS_FLOAT32X16 y = + MS_OR512_F32(MS_CAST512_F32_S32(gFloatExp0), MS512_ANDNOT_F32(MS_CAST512_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X16 div = MS_DIV512_F32(MS_ADD512_F32(y, neg), MS_ADD512_F32(y, pos)); + MS_FLOAT32X16 square = MS_MUL512_F32(div, div); + + MS_FLOAT32X16 tmp = MS_ADD512_F32( + MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(square, MS_ADD512_F32(MS_MUL512_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X16 tmp1 = MS_MUL512_F32(square, MS_ADD512_F32(MS_MUL512_F32(square, tmp), data4)); + MS_FLOAT32X16 res = + MS_ADD512_F32(MS_MUL512_F32(ln2, expsPD), MS_MUL512_F32(MS_MUL512_F32(div, MS_ADD512_F32(tmp1, data5)), 
data6)); + MS_MASK512_TYPE mask = MS_CMP512_F32(src, MS_MOV512_F32(0.0f), _CMP_EQ_OQ); + res = MS_BLEND512_F32(res, MS_MOV512_F32(-INFINITY), mask); + mask = MS_CMP512_F32(src, MS_MOV512_F32(INFINITY), _CMP_EQ_OQ); + res = MS_BLEND512_F32(res, MS_MOV512_F32(INFINITY), mask); + mask = MS_CMPLT512_F32(src, MS_MOV512_F32(0.0f)); + res = MS_BLEND512_F32(res, MS_MOV512_F32(NAN), mask); + mask = MS_CMP512_F32(src, MS_MOV512_F32(0.0f), _CMP_UNORD_Q); + res = MS_BLEND512_F32(res, MS_MOV512_F32(NAN), mask); + return res; +} + +#define MS_DIV512_EPI32(src1, src2) \ + _mm512_cvttps_epi32(MS_DIV512_F32(_mm512_cvtepi32_ps(src1), _mm512_cvtepi32_ps(src2))) + +#define MS512_INT16_TO_FLOAT16(src) _mm512_cvtepi16_ph(src) +#define MS512_FLOAT16_TO_INT16(src) _mm512_cvttph_epi16(src) + +#define MS512_INT32_TO_FLOAT16(src) _mm512_cvtepi32_ph(src) +#define MS512_FLOAT16_TO_INT32(src) _mm512_cvttph_epi32(src) + +#define MS512_INT32_TO_FLOAT32(src) _mm512_cvtepi32_ps(src) +#define MS512_FLOAT32_TO_INT32(src) _mm512_cvttps_epi32(src) +#define MS512_FLOAT16_TO_FLOAT32(src) _mm512_cvtph_ps(src) +#define MS512_FLOAT32_TO_FLOAT16(src1, src2) _mm512_cvtps_ph(src1, src2) + +#define MS512_INT64_TO_FLOAT32(src) _mm512_cvtepi64_ps(src) +#define MS512_FLOAT32_TO_INT64(src) _mm512_cvttps_epi64(src) + +#define MS512_INT64_TO_FLOAT16(src) _mm512_cvtepi64_ph(src) +#define MS512_FLOAT16_TO_INT64(src) _mm512_cvttph_epi64(src) + +#define MS512_INT32_TO_FLOAT64(src) _mm512_cvtepi32_pd(src) +#define MS512_FLOAT64_TO_INT32(src) _mm512_cvttpd_epi32(src) + +#define MS512_INT64_TO_FLOAT64(src) _mm512_cvtepi64_pd(src) +#define MS512_FLOAT64_TO_INT64(src) _mm512_cvttpd_epi64(src) + +#define MS512_INT16_TO_INT32(src) _mm512_cvtepi16_epi32(src) +#define MS512_INT16_TO_INT64(src) _mm512_cvtepi16_epi64(src) +#define MS512_INT32_TO_INT16(src) _mm512_cvtepi32_epi16(src) +#define MS512_INT32_TO_INT64(src) _mm512_cvtepi32_epi64(src) +#define MS512_INT64_TO_INT16(src) _mm512_cvtepi64_epi16(src) +#define MS512_INT64_TO_INT32(src) _mm512_cvtepi64_epi32(src) + +static inline MS_FLOAT32X16 simd_exp512_f32(MS_FLOAT32X16 input) { + static MS_FLOAT32X16 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X16 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + static MS_FLOAT32X16 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, + 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, + 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, + 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, + 1.0f / 6, 1.0f / 6, 
1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}}; + + input = MS_MAX512_F32(minv, MS_MIN512_F32(input, maxv)); + MS_INT32X16 integer = MS_CVT512PS_EPI32(MS_FLOOR512_F32(MS_FMADD512_F32(input, param[6], param[4]))); + MS_FLOAT32X16 decimal = MS_SUB512_F32(input, MS_MUL512_F32(MS_CVT512EPI32_PS(integer), param[0])); + MS_INT32X16 int_exp = MS_SLLI512_EPI32(MS_ADD512_EPI32(integer, MS_MOV512_EPI32(126)), 23); + MS_FLOAT32X16 tmp = MS_FMADD512_F32(decimal, MS_FMADD512_F32(decimal, param[1], param[2]), param[3]); + tmp = MS_FMADD512_F32(decimal, MS_FMADD512_F32(decimal, tmp, param[4]), param[5]); + MS_FLOAT32X16 decimal_exp = MS_FMADD512_F32(decimal, tmp, param[5]); + return MS_MUL512_F32(param[7], MS_MUL512_F32(decimal_exp, MS_CAST512_F32_S32(int_exp))); +} + +static inline MS_FLOAT32X16 simd_hexp512_f32(MS_FLOAT32X16 src) { + MS_FLOAT32X16 dst; + MS512_F32_GETI(dst, 0) = exp(MS512_F32_GETI(src, 0)); + MS512_F32_GETI(dst, 1) = exp(MS512_F32_GETI(src, 1)); + MS512_F32_GETI(dst, 2) = exp(MS512_F32_GETI(src, 2)); + MS512_F32_GETI(dst, 3) = exp(MS512_F32_GETI(src, 3)); + MS512_F32_GETI(dst, 4) = exp(MS512_F32_GETI(src, 4)); + MS512_F32_GETI(dst, 5) = exp(MS512_F32_GETI(src, 5)); + MS512_F32_GETI(dst, 6) = exp(MS512_F32_GETI(src, 6)); + MS512_F32_GETI(dst, 7) = exp(MS512_F32_GETI(src, 7)); + MS512_F32_GETI(dst, 8) = exp(MS512_F32_GETI(src, 8)); + MS512_F32_GETI(dst, 9) = exp(MS512_F32_GETI(src, 9)); + MS512_F32_GETI(dst, 10) = exp(MS512_F32_GETI(src, 10)); + MS512_F32_GETI(dst, 11) = exp(MS512_F32_GETI(src, 11)); + MS512_F32_GETI(dst, 12) = exp(MS512_F32_GETI(src, 12)); + MS512_F32_GETI(dst, 13) = exp(MS512_F32_GETI(src, 13)); + MS512_F32_GETI(dst, 14) = exp(MS512_F32_GETI(src, 14)); + MS512_F32_GETI(dst, 15) = exp(MS512_F32_GETI(src, 15)); + return dst; +} + +static inline void simd_exp512(MS_FLOAT32X16 input, float *dst) { + MS_FLOAT32X16 res = simd_exp512_f32(input); + MS_ST512_F32(dst, res); +} + +static inline MS_FLOAT32X16 MS_TANHX16_F32(MS_FLOAT32X16 src) { + static const MS_FLOAT32X16 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, + 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X16 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, + 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X16 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X16 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, + 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X16 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, + 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 
3150.0f}; + static const MS_FLOAT32X16 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, + 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X16 square = MS_MUL512_F32(src, src); + MS_FLOAT32X16 a = + MS_MUL512_F32(MS_FMADD512_F32(MS_FMADD512_F32(MS_ADD512_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X16 b = + MS_FMADD512_F32(MS_FMADD512_F32(MS_FMADD512_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X16 res = MS_DIV512_F32(a, b); + MS_FLOAT32X16 up_limit = MS_MOV512_F32(5.0f); + MS_FLOAT32X16 down_limit = MS_MOV512_F32(-5.0f); + MS_MASK512_TYPE up_mask = MS_CMPGT512_F32(src, up_limit); + MS_MASK512_TYPE down_mask = MS_CMPLT512_F32(src, down_limit); + res = MS_BLEND512_F32(res, pos, up_mask); + res = MS_BLEND512_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH512_F32 MS_TANHX16_F32 + +static inline MS_FLOAT32X16 MS512_ERF_F32(MS_FLOAT32X16 src) { + MS_FLOAT32X16 dst; + MS_F32X16_GETI(dst, 0) = erff(MS_F32X16_GETI(src, 0)); + MS_F32X16_GETI(dst, 1) = erff(MS_F32X16_GETI(src, 1)); + MS_F32X16_GETI(dst, 2) = erff(MS_F32X16_GETI(src, 2)); + MS_F32X16_GETI(dst, 3) = erff(MS_F32X16_GETI(src, 3)); + MS_F32X16_GETI(dst, 4) = erff(MS_F32X16_GETI(src, 4)); + MS_F32X16_GETI(dst, 5) = erff(MS_F32X16_GETI(src, 5)); + MS_F32X16_GETI(dst, 6) = erff(MS_F32X16_GETI(src, 6)); + MS_F32X16_GETI(dst, 7) = erff(MS_F32X16_GETI(src, 7)); + MS_F32X16_GETI(dst, 8) = erff(MS_F32X16_GETI(src, 8)); + MS_F32X16_GETI(dst, 9) = erff(MS_F32X16_GETI(src, 9)); + MS_F32X16_GETI(dst, 10) = erff(MS_F32X16_GETI(src, 10)); + MS_F32X16_GETI(dst, 11) = erff(MS_F32X16_GETI(src, 11)); + MS_F32X16_GETI(dst, 12) = erff(MS_F32X16_GETI(src, 12)); + MS_F32X16_GETI(dst, 13) = erff(MS_F32X16_GETI(src, 13)); + MS_F32X16_GETI(dst, 14) = erff(MS_F32X16_GETI(src, 14)); + MS_F32X16_GETI(dst, 15) = erff(MS_F32X16_GETI(src, 15)); + return dst; +} + +#define MS_LOAD512X8_F32(src, input_ptr, num) \ + MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \ + MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \ + MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \ + MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); \ + MS_FLOAT32X16 src##5 = MS_LD512_F32(input_ptr + 4 * num); \ + MS_FLOAT32X16 src##6 = MS_LD512_F32(input_ptr + 5 * num); \ + MS_FLOAT32X16 src##7 = MS_LD512_F32(input_ptr + 6 * num); \ + MS_FLOAT32X16 src##8 = MS_LD512_F32(input_ptr + 7 * num); + +#define MS_LOAD512X4_F32(src, input_ptr, num) \ + MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \ + MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \ + MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \ + MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); + +#define MS_FMADD512X8_F32(src, weight, dst) \ + dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA512_F32(dst##4, src##4, weight); \ + dst##5 = MS_MLA512_F32(dst##5, src##5, weight); \ + dst##6 = MS_MLA512_F32(dst##6, src##6, weight); \ + dst##7 = MS_MLA512_F32(dst##7, src##7, weight); \ + dst##8 = MS_MLA512_F32(dst##8, src##8, weight); + +#define 
MS_FMADD512X4_F32(src, weight, dst) \ + dst##1 = MS_MLA512_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLA512_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLA512_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLA512_F32(src##4, weight, dst##4); + +#define MS_SET_ZERO512X8_F32(dst) \ + MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##5 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##6 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##7 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##8 = _mm512_setzero_ps(); + +#define MS_SET_ZERO512X4_F32(dst) \ + MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); + +#pragma GCC pop_options + +#endif // NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx_instructions.h new file mode 100644 index 00000000..2b3647d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_avx_instructions.h @@ -0,0 +1,440 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include <math.h> + +#ifdef _MSC_VER +#include <immintrin.h> +#define MS_F32X8_GETI(src, i) src.m256_f32[i] +#define MS256_F32_GETI(src, i) src.m256_f32[i] +#else +#include <x86intrin.h> +#define MS_F32X8_GETI(src, i) src[i] +#define MS256_F32_GETI(src, i) src[i] +#endif + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X8 __m256 +#define MS_FLOAT256_F32 __m256 +#define MS_INT32X8 __m256i +#define MS_INT256_EPI32 __m256i +#define MS_MASK256_TYPE MS_FLOAT32X8 +#define MS_LD256_F32 _mm256_loadu_ps +#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) +#define MS_ADD256_F32 _mm256_add_ps +#define MS_ADD256_EPI32 _mm256_add_epi32 +#define MS_MOV256_F32 _mm256_set1_ps +#define MS_MOV256_EPI32 _mm256_set1_epi32 +#define MS_MOV256_VAL0_F32 _mm256_setzero_ps() +#define MS_MLA256_F32(src1, src2, src3) _mm256_fmadd_ps(src2, src3, src1) +#define MS_ST256_F32 _mm256_storeu_ps +#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) +#define MS_SUB256_F32 _mm256_sub_ps +#define MS_SUB256_EPI32 _mm256_sub_epi32 +#define MS_MAX256_F32 _mm256_max_ps +#define MS_MAX256_EPI32 _mm256_max_epi32 +#define MS_MIN256_F32 _mm256_min_ps +#define MS_MIN256_EPI32 _mm256_min_epi32 +#define MS_SQRT256_F32 _mm256_sqrt_ps +#define MS_RSQRT256_F32 _mm256_rsqrt_ps +#define MS_SIN256_F32 _mm256_sin_ps +#define MS_ERF256_F32 _mm256_erf_ps +#define MS_ROUND256_F32(src) _mm256_round_ps(src, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR256_F32 _mm256_floor_ps +#define MS_CEIL256_F32 _mm256_ceil_ps +#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2) +#define MS_MUL256_EPI32(src1, src2) _mm256_mullo_epi32(src1, src2) +#define MS_FMADD256_F32(src1, src2, src3) _mm256_fmadd_ps(src1, src2, src3) +#define MS_FMSUB256_F32(src1, src2, src3) _mm256_fmsub_ps(src1, src2, src3) +#define MS_FSMUL256_F32(src1, src2, src3) _mm256_fnmadd_ps(src3, src2, src1) // src1 - src2 * src3 +#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2) +#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2)) +#define MS_MUL256_N_EPI32(src1, src2) _mm256_mullo_epi32(src1, _mm256_set1_epi32(src2)) +#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2)) +#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2) +#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src) // truncate float to int +#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src) +#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3) +#define MS_CMPGT256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 30) +#define MS_CMPLE256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 18) +#define MS_CMPLT256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 17) +#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2) +#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3) +#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3) +#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src) +#define MS_AND256_MASK(src1, src2) _mm256_and_ps(src1, src2) +#define MS_OR256_F32(src1, src2) _mm256_or_ps(src1, src2) +#define MS_AND256_MASK_F32(src1, src2) _mm256_and_ps(src1, src2) +#define MS_AND256_F32(src1, src2) _mm256_and_ps(src1, src2) + +#define MS256_ANDNOT_F32(src1, src2) _mm256_andnot_ps(src1, src2) +#define MS256_SRLI_EPI32(src1, src2) _mm256_srli_epi32(src1, src2) +#define MS256_AND_EPI32(src1, src2) _mm256_and_si256(src1, src2)
+#define MS256_CASTPS_EPI32(src) _mm256_castps_si256(src) + +static inline MS_FLOAT32X8 MS_POW256_F32(MS_FLOAT32X8 src1, MS_FLOAT32X8 src2) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = powf(MS_F32X8_GETI(src1, 0), MS_F32X8_GETI(src2, 0)); + MS_F32X8_GETI(dst, 1) = powf(MS_F32X8_GETI(src1, 1), MS_F32X8_GETI(src2, 1)); + MS_F32X8_GETI(dst, 2) = powf(MS_F32X8_GETI(src1, 2), MS_F32X8_GETI(src2, 2)); + MS_F32X8_GETI(dst, 3) = powf(MS_F32X8_GETI(src1, 3), MS_F32X8_GETI(src2, 3)); + MS_F32X8_GETI(dst, 4) = powf(MS_F32X8_GETI(src1, 4), MS_F32X8_GETI(src2, 4)); + MS_F32X8_GETI(dst, 5) = powf(MS_F32X8_GETI(src1, 5), MS_F32X8_GETI(src2, 5)); + MS_F32X8_GETI(dst, 6) = powf(MS_F32X8_GETI(src1, 6), MS_F32X8_GETI(src2, 6)); + MS_F32X8_GETI(dst, 7) = powf(MS_F32X8_GETI(src1, 7), MS_F32X8_GETI(src2, 7)); + return dst; +} + +static inline MS_FLOAT32X8 MS_ABS256_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = fabsf(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = fabsf(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = fabsf(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = fabsf(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = fabsf(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = fabsf(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = fabsf(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = fabsf(MS_F32X8_GETI(src, 7)); + return dst; +} + +static inline MS_FLOAT256_F32 SIMD_SIGN256_F32(MS_FLOAT256_F32 src) { + MS_FLOAT256_F32 abs_src = MS_ABS256_F32(src); + MS_FLOAT256_F32 sign = MS_DIV256_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS256_F32(src, abs_src) MS_DIV256_F32(abs_src, src) + +static inline MS_FLOAT32X8 MS_COS256_F32(MS_FLOAT32X8 src) { + static const MS_FLOAT32X8 pi = {PI, PI, PI, PI, PI, PI, PI, PI}; + static const MS_FLOAT32X8 pi2_neg = { + -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, + }; + static const MS_FLOAT32X8 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT256_F32 src_abs = MS_ABS256_F32(src); + MS_FLOAT256_F32 src_cycle = + MS_ADD256_F32(MS_MUL256_F32(MS_FLOOR256_F32(MS_MUL256_F32(MS_ADD256_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + + static const MS_FLOAT256_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT256_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT256_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT256_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT256_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + MS_FLOAT32X8 square = MS_MUL256_F32(src_cycle, src_cycle); + + MS_FLOAT32X8 tmp = + MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_MUL256_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X8 tmp1 = MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(tmp, neg), square), data2); + MS_FLOAT256_F32 res = MS_ADD256_F32( + MS_MUL256_F32( + MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + 
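+/* Annotation (not in the original header): MS_COS256_F32 relies on cos(-x) = cos(x), folding |x|
+ * into [-PI, PI) via floor((|x| + PI) / (2 * PI)) * (-2 * PI) + |x|, then evaluates the degree-10
+ * Maclaurin truncation in nested form:
+ *
+ *   cos(t) ~= 1 - t^2/2 * (1 - t^2/12 * (1 - t^2/30 * (1 - t^2/56 * (1 - t^2/90))))
+ *
+ * Each constant is 1/((2k)(2k - 1)), the ratio between consecutive factorial terms
+ * (2 * 12 * 30 * 56 * 90 = 10!), so only small coefficients and even powers of t appear.
+ * The AVX512 variant MS_COS512_F32 uses the same scheme on 16 lanes. */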
+static inline MS_FLOAT32X8 MS256_LOG_F32(MS_FLOAT32X8 src) { + const MS_INT256_EPI32 gFloatExpMask = MS_MOV256_EPI32(0xffULL << 23); + const MS_INT256_EPI32 gFloatExp0 = MS_MOV256_EPI32(127ULL << 23); + const MS_INT256_EPI32 gExpNormalizer = MS_MOV256_EPI32(127); + static const MS_FLOAT256_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT256_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT256_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT256_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT256_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT256_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT256_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X8 ln2 = {LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2}; + + const MS_INT256_EPI32 exps32 = MS256_SRLI_EPI32(MS256_AND_EPI32(gFloatExpMask, MS256_CASTPS_EPI32(src)), 23); + const MS_INT256_EPI32 normExps = MS_SUB256_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X8 expsPD = MS_CVT256EPI32_PS(normExps); + const MS_FLOAT32X8 y = + MS_OR256_F32(MS_CAST256_F32_S32(gFloatExp0), MS256_ANDNOT_F32(MS_CAST256_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X8 div = MS_DIV256_F32(MS_ADD256_F32(y, neg), MS_ADD256_F32(y, pos)); + MS_FLOAT32X8 square = MS_MUL256_F32(div, div); + + MS_FLOAT32X8 tmp = MS_ADD256_F32( + MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(square, MS_ADD256_F32(MS_MUL256_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X8 tmp1 = MS_MUL256_F32(square, MS_ADD256_F32(MS_MUL256_F32(square, tmp), data4)); + MS_FLOAT32X8 res = + MS_ADD256_F32(MS_MUL256_F32(ln2, expsPD), MS_MUL256_F32(MS_MUL256_F32(div, MS_ADD256_F32(tmp1, data5)), data6)); + MS_FLOAT32X8 mask = MS_CMP256_F32(src, MS_MOV256_F32(0.0f), _CMP_EQ_OQ); + res = MS_BLEND256_F32(res, MS_MOV256_F32(-INFINITY), mask); + mask = MS_CMP256_F32(src, MS_MOV256_F32(INFINITY), _CMP_EQ_OQ); + res = MS_BLEND256_F32(res, MS_MOV256_F32(INFINITY), mask); + mask = MS_OR256_F32(MS_CMPLT256_F32(src, MS_MOV256_F32(0.0f)), MS_CMP256_F32(src, MS_MOV256_F32(0.0f), _CMP_UNORD_Q)); + res = MS_BLEND256_F32(res, MS_MOV256_F32(NAN), mask); + return res; +} + +static inline float MS_GET_MAX256_F32(__m256 src) { + float result = MS_F32X8_GETI(src, 0); + for (int i = 1; i < 8; i++) { // avx block num : 8 + result = fmaxf(result, MS_F32X8_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM256_F32(__m256 src) { + float result = MS_F32X8_GETI(src, 0); + for (int i = 1; i < 8; i++) { // avx block num : 8 + result = result + MS_F32X8_GETI(src, i); + } + return result; +} + +#define MS_DIV256_EPI32(src1, src2) \ + _mm256_cvttps_epi32(MS_DIV256_F32(_mm256_cvtepi32_ps(src1), _mm256_cvtepi32_ps(src2))) + +#define MS256_INT16_TO_FLOAT16(src) _mm256_cvtepi16_ph(src) +#define MS256_FLOAT16_TO_INT16(src) _mm256_cvttph_epi16(src) + +#define MS256_INT32_TO_FLOAT16(src) _mm256_cvtepi32_ph(src) +#define MS256_FLOAT16_TO_INT32(src) _mm256_cvttph_epi32(src) + +#define MS256_INT32_TO_FLOAT32(src) _mm256_cvtepi32_ps(src) +#define 
MS256_FLOAT32_TO_INT32(src) _mm256_cvttps_epi32(src) + +#define MS256_INT64_TO_FLOAT32(src) _mm256_cvtepi64_ps(src) +#define MS256_FLOAT32_TO_INT64(src) _mm256_cvttps_epi64(src) + +#define MS256_INT64_TO_FLOAT16(src) _mm256_cvtepi64_ph(src) +#define MS256_FLOAT16_TO_INT64(src) _mm256_cvttph_epi64(src) + +#define MS256_INT32_TO_FLOAT64(src) _mm256_cvtepi32_pd(src) +#define MS256_FLOAT64_TO_INT32(src) _mm256_cvttpd_epi32(src) + +#define MS256_INT64_TO_FLOAT64(src) _mm256_cvtepi64_pd(src) +#define MS256_FLOAT64_TO_INT64(src) _mm256_cvttpd_epi64(src) + +#define MS256_INT16_TO_INT32(src) _mm256_cvtepi16_epi32(src) +#define MS256_INT16_TO_INT64(src) _mm256_cvtepi16_epi64(src) +#define MS256_INT32_TO_INT16(src) _mm256_cvtepi32_epi16(src) +#define MS256_INT32_TO_INT64(src) _mm256_cvtepi32_epi64(src) +#define MS256_INT64_TO_INT16(src) _mm256_cvtepi64_epi16(src) +#define MS256_INT64_TO_INT32(src) _mm256_cvtepi64_epi32(src) + +static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = sqrtf(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = sqrtf(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = sqrtf(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = sqrtf(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = sqrtf(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = sqrtf(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = sqrtf(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = sqrtf(MS_F32X8_GETI(src, 7)); + return dst; +} + +#define MS_LOAD256X4_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); + +#define MS_LOAD256X8_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \ + MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \ + MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \ + MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \ + MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); + +#define MS_LOAD256X16_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \ + MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \ + MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \ + MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \ + MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); \ + MS_FLOAT32X8 src##9 = MS_LD256_F32(input_ptr + 8 * num); \ + MS_FLOAT32X8 src##10 = MS_LD256_F32(input_ptr + 9 * num); \ + MS_FLOAT32X8 src##11 = MS_LD256_F32(input_ptr + 10 * num); \ + MS_FLOAT32X8 src##12 = MS_LD256_F32(input_ptr + 11 * num); \ + MS_FLOAT32X8 src##13 = MS_LD256_F32(input_ptr + 12 * num); \ + MS_FLOAT32X8 src##14 = MS_LD256_F32(input_ptr + 13 * num); \ + MS_FLOAT32X8 src##15 = MS_LD256_F32(input_ptr + 14 * num); \ + MS_FLOAT32X8 src##16 = MS_LD256_F32(input_ptr + 15 * num); + +#define STORE256X8_F32(output_ptr, num, dst) \ + MS_ST256_F32(output_ptr + 0 * num, dst##1); \ + MS_ST256_F32(output_ptr + 1 * num, dst##2); \ + MS_ST256_F32(output_ptr + 2 * num, dst##3); \ + 
MS_ST256_F32(output_ptr + 3 * num, dst##4); \ + MS_ST256_F32(output_ptr + 4 * num, dst##5); \ + MS_ST256_F32(output_ptr + 5 * num, dst##6); \ + MS_ST256_F32(output_ptr + 6 * num, dst##7); \ + MS_ST256_F32(output_ptr + 7 * num, dst##8); + +#define STORE256X16_F32(output_ptr, num, dst) \ + MS_ST256_F32(output_ptr + 0 * num, dst##1); \ + MS_ST256_F32(output_ptr + 1 * num, dst##2); \ + MS_ST256_F32(output_ptr + 2 * num, dst##3); \ + MS_ST256_F32(output_ptr + 3 * num, dst##4); \ + MS_ST256_F32(output_ptr + 4 * num, dst##5); \ + MS_ST256_F32(output_ptr + 5 * num, dst##6); \ + MS_ST256_F32(output_ptr + 6 * num, dst##7); \ + MS_ST256_F32(output_ptr + 7 * num, dst##8); \ + MS_ST256_F32(output_ptr + 8 * num, dst##9); \ + MS_ST256_F32(output_ptr + 9 * num, dst##10); \ + MS_ST256_F32(output_ptr + 10 * num, dst##11); \ + MS_ST256_F32(output_ptr + 11 * num, dst##12); \ + MS_ST256_F32(output_ptr + 12 * num, dst##13); \ + MS_ST256_F32(output_ptr + 13 * num, dst##14); \ + MS_ST256_F32(output_ptr + 14 * num, dst##15); \ + MS_ST256_F32(output_ptr + 15 * num, dst##16); + +static inline MS_FLOAT32X8 simd_exp256_f32(MS_FLOAT32X8 input) { + static MS_FLOAT32X8 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X8 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + static MS_FLOAT32X8 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}}; + input = MS_MAX256_F32(minv, MS_MIN256_F32(input, maxv)); + MS_INT32X8 integer = MS_CVT256PS_EPI32(MS_FLOOR256_F32(MS_FMADD256_F32(input, param[6], param[4]))); + MS_FLOAT32X8 decimal = MS_SUB256_F32(input, MS_MUL256_F32(MS_CVT256EPI32_PS(integer), param[0])); + MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(126)), 23); + MS_FLOAT32X8 tmp = MS_FMADD256_F32(decimal, MS_FMADD256_F32(decimal, param[1], param[2]), param[3]); + tmp = MS_FMADD256_F32(decimal, MS_FMADD256_F32(decimal, tmp, param[4]), param[5]); + MS_FLOAT32X8 decimal_exp = MS_FMADD256_F32(decimal, tmp, param[5]); + return MS_MUL256_F32(param[7], MS_MUL256_F32(decimal_exp, MS_CAST256_F32_S32(int_exp))); +} + +static inline MS_FLOAT32X8 simd_hexp256_f32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = exp(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = exp(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = exp(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = exp(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = exp(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = exp(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = exp(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = exp(MS_F32X8_GETI(src, 7)); + return dst; +} + +static inline void simd_exp256(MS_FLOAT32X8 
input, float *dst) { + MS_FLOAT32X8 res = simd_exp256_f32(input); + MS_ST256_F32(dst, res); +} + +static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) { + static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X8 square = MS_MUL256_F32(src, src); + MS_FLOAT32X8 a = + MS_MUL256_F32(MS_FMADD256_F32(MS_FMADD256_F32(MS_ADD256_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X8 b = + MS_FMADD256_F32(MS_FMADD256_F32(MS_FMADD256_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X8 res = MS_DIV256_F32(a, b); + MS_FLOAT32X8 up_limit = MS_MOV256_F32(5.0f); + MS_FLOAT32X8 down_limit = MS_MOV256_F32(-5.0f); + MS_FLOAT32X8 up_mask = MS_CMPGT256_F32(src, up_limit); + MS_FLOAT32X8 down_mask = MS_CMPLT256_F32(src, down_limit); + res = MS_BLEND256_F32(res, pos, up_mask); + res = MS_BLEND256_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH256_F32 MS_TANHX8_F32 + +static inline MS_FLOAT32X8 MS256_ERF_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = erff(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = erff(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = erff(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = erff(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = erff(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = erff(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = erff(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = erff(MS_F32X8_GETI(src, 7)); + return dst; +} + +#define MS_FMADD256X8_F32(src, weight, dst) \ + dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA256_F32(dst##4, src##4, weight); \ + dst##5 = MS_MLA256_F32(dst##5, src##5, weight); \ + dst##6 = MS_MLA256_F32(dst##6, src##6, weight); \ + dst##7 = MS_MLA256_F32(dst##7, src##7, weight); \ + dst##8 = MS_MLA256_F32(dst##8, src##8, weight); + +#define MS_SET_ZERO256X8_F32(dst) \ + MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##5 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##6 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##7 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##8 = _mm256_setzero_ps(); + +#define MS_FMADD256X4_F32(src, weight, dst) \ + dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA256_F32(dst##4, src##4, weight); + +#define MS_SET_ZERO256X4_F32(dst) \ + MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##3 = 
_mm256_setzero_ps();                    \
+  MS_FLOAT32X8 dst##4 = _mm256_setzero_ps();
+
+#define MS_REDUCE_ADD256_F32(src) (src = _mm256_hadd_ps(src, src), src = _mm256_hadd_ps(src, src), src[0] + src[4]);
+#endif  // NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c
new file mode 100644
index 00000000..8348f795
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.c
@@ -0,0 +1,141 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/intrinsics/ms_simd_cpu_info.h"
+#include <string.h>
+#include <stdlib.h>
+#include "nnacl_c/errorcode.h"
+
+typedef unsigned int DWORD;
+struct X86CpuInfoContext {
+  bool fma_flag_;
+  bool sse4_1_flag_;
+  bool avx2_flag_;
+  bool avx512_flag_;
+};
+
+static struct X86CpuInfoContext g_x86_cpu_info_context_;
+
+inline const bool X86_Fma_Support(void) { return g_x86_cpu_info_context_.fma_flag_; }
+
+inline const bool X86_Sse_Support(void) {
+#ifdef ENABLE_SSE
+  return g_x86_cpu_info_context_.sse4_1_flag_;
+#else
+  return false;
+#endif
+}
+
+inline const bool X86_Avx_Support(void) {
+#ifdef ENABLE_AVX
+  return g_x86_cpu_info_context_.avx2_flag_;
+#else
+  return false;
+#endif
+}
+
+inline const bool X86_Avx512_Support(void) {
+#ifdef ENABLE_AVX512
+  return g_x86_cpu_info_context_.avx512_flag_;
+#else
+  return false;
+#endif
+}
+
+void ExecuteCpuIdCmd(DWORD cmd_code, DWORD *eax_data, DWORD *ebx_data, DWORD *ecx_data, DWORD *edx_data) {
+  DWORD deax, debx, decx, dedx;
+  asm volatile(
+    "movl %4, %%eax;\n"
+    "movl $0, %%ecx;\n"
+    "cpuid;\n"
+    "movl %%eax, %0;\n"
+    "movl %%ebx, %1;\n"
+    "movl %%ecx, %2;\n"
+    "movl %%edx, %3;\n"
+    : "=r"(deax), "=r"(debx), "=r"(decx), "=r"(dedx)
+    : "r"(cmd_code)
+    : "%eax", "%ebx", "%ecx", "%edx");
+
+  *eax_data = deax;
+  *ebx_data = debx;
+  *ecx_data = decx;
+  *edx_data = dedx;
+}
+
+bool IsIntelX86Platform(void) {
+  DWORD eax_data, ebx_data, ecx_data, edx_data;
+
+  const int vid_info_size = 13;
+  char *vid_info = malloc(sizeof(char) * vid_info_size);
+  if (vid_info == NULL) {
+    return false;
+  }
+  memset(vid_info, 0, vid_info_size);
+
+  ExecuteCpuIdCmd(0, &eax_data, &ebx_data, &ecx_data, &edx_data);  // eax = 0, execute cpuid to get vid info
+
+  memcpy(vid_info, &ebx_data, 4);      // Copy the first 4 characters to the array[0:3]
+  memcpy(vid_info + 4, &edx_data, 4);  // Copy the middle 4 characters to the array[4:8]
+  memcpy(vid_info + 8, &ecx_data, 4);  // Copy the last 4 characters to the array[8:12]
+
+  int x86_intel_flag = (strcmp(vid_info, "GenuineIntel") == 0 || strcmp(vid_info, "AuthenticAMD") == 0) ? 1 : 0;
+
+  free(vid_info);
+  return x86_intel_flag;
+}
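/* Editor's note: a hedged aside, not part of the patch. CPUID leaf 0 returns the
   12-byte vendor string split across EBX, EDX, ECX in that order, which is why the
   copies above go ebx, edx, ecx. A minimal standalone check (illustrative only,
   assumes the little-endian layout cpuid implies on x86):

     #include <stdio.h>
     #include <string.h>
     int main(void) {
       unsigned ebx = 0x756e6547, edx = 0x49656e69, ecx = 0x6c65746e;  // "Genu", "ineI", "ntel"
       char vid[13] = {0};
       memcpy(vid, &ebx, 4);
       memcpy(vid + 4, &edx, 4);
       memcpy(vid + 8, &ecx, 4);
       printf("%s\n", vid);  // prints GenuineIntel
       return 0;
     }
*/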
+
+int IntelX86CpuInfoInit(void) {
+  if (!IsIntelX86Platform()) {
+    return NNACL_ERR;
+  }
+  DWORD eax_data, ebx_data, ecx_data, edx_data;
+  ExecuteCpuIdCmd(1, &eax_data, &ebx_data, &ecx_data, &edx_data);  // eax = 1, execute cpuid to get sse/fma flag
+  g_x86_cpu_info_context_.sse4_1_flag_ = (ecx_data & (1 << 19)) == 0 ? false : true;  // sse4.1 flag is ecx bit 19
+  g_x86_cpu_info_context_.fma_flag_ = (ecx_data & (1 << 12)) == 0 ? false : true;     // fma flag is ecx bit 12
+
+  ExecuteCpuIdCmd(7, &eax_data, &ebx_data, &ecx_data, &edx_data);  // eax = 7, execute cpuid to get avx2/avx512 flag
+  g_x86_cpu_info_context_.avx2_flag_ = (ebx_data & (1 << 5)) == 0 ? false : true;     // avx2 flag is ebx bit 5
+  g_x86_cpu_info_context_.avx512_flag_ = (ebx_data & (1 << 16)) == 0 ? false : true;  // avx512 flag is ebx bit 16
+
+  return NNACL_OK;
+}
+
+X86CpuInfoErrorCodeEnum IntelX86InstructionSetSupportCheck(void) {
+  if (IntelX86CpuInfoInit() != NNACL_OK) {
+    return X86CPUINFO_PLATFORM_ERR;
+  }
+#if defined(ENABLE_AVX512) && !defined(AVX512_HARDWARE_SELF_AWARENESS)
+  if (!X86_Avx512_Support()) {
+    return X86CPUINFO_AVX512_ERR;
+  }
+#endif
+
+#ifdef ENABLE_AVX
+  if (!X86_Avx_Support()) {
+    return X86CPUINFO_AVX_ERR;
+  }
+#endif
+
+#ifdef ENABLE_SSE
+  if (!X86_Sse_Support()) {
+    return X86CPUINFO_SSE_ERR;
+  }
+#endif
+  return X86CPUINFO_OK;
+}
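/* Editor's note: a hedged usage sketch, not part of the patch. A backend init
   routine might gate kernel registration on the runtime check; the caller name
   lite_x86_backend_init is hypothetical:

     int lite_x86_backend_init(void) {
       X86CpuInfoErrorCodeEnum ret = IntelX86InstructionSetSupportCheck();
       if (ret != X86CPUINFO_OK) {
         return ret;  // the binary was built for an ISA this CPU lacks
       }
       return 0;  // safe to use the compiled SIMD kernels
     }
*/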
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h
new file mode 100644
index 00000000..cec5ef13
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_cpu_info.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_MS_SIMD_CPU_INFO_H_
+#define NNACL_MS_SIMD_CPU_INFO_H_
+
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef ENABLE_AVX512
+#define AVX512_HARDWARE_SELF_AWARENESS
+#endif
+
+#if defined(AVX512_HARDWARE_SELF_AWARENESS)
+#define AVX512_HARDWARE_SELF_AWARENESS_BEGIN if (X86_Avx512_Support()) {
+#define AVX512_HARDWARE_SELF_AWARENESS_END }
+#else
+#define AVX512_HARDWARE_SELF_AWARENESS_BEGIN
+#define AVX512_HARDWARE_SELF_AWARENESS_END
+#endif
+
+typedef enum X86CpuInfoErrorCodeEnum {
+  X86CPUINFO_OK = 0,
+  X86CPUINFO_PLATFORM_ERR = 1,
+  X86CPUINFO_AVX512_ERR,
+  X86CPUINFO_AVX_ERR,
+  X86CPUINFO_SSE_ERR,
+  X86CPUINFO_END = 9999
+} X86CpuInfoErrorCodeEnum;
+
+const bool X86_Fma_Support(void);
+const bool X86_Sse_Support(void);
+const bool X86_Avx_Support(void);
+const bool X86_Avx512_Support(void);
+
+bool IsIntelX86Platform(void);
+X86CpuInfoErrorCodeEnum IntelX86InstructionSetSupportCheck(void);
+
+int IntelX86CpuInfoInit(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // NNACL_MS_SIMD_CPU_INFO_H_
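/* Editor's note: a hedged sketch, not part of the patch, of how the
   self-awareness guards are meant to be used: the AVX-512 body is compiled in,
   but executed only once the runtime check passes. DoAvx512Body is an
   illustrative placeholder:

     void DispatchExample(void) {
       AVX512_HARDWARE_SELF_AWARENESS_BEGIN  // expands to: if (X86_Avx512_Support()) {
       DoAvx512Body();
       AVX512_HARDWARE_SELF_AWARENESS_END    // expands to: }
     }
*/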
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h
new file mode 100644
index 00000000..030c610a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h
@@ -0,0 +1,563 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#define NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#include <math.h>
+#include "nnacl_c/intrinsics/ms_simd_cpu_info.h"
+
+#ifdef ENABLE_AVX512
+#include "nnacl_c/intrinsics/ms_simd_avx512_instructions.h"
+#endif
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/intrinsics/ms_simd_avx_instructions.h"
+#endif
+
+#ifdef ENABLE_SSE
+#include "nnacl_c/intrinsics/ms_simd_sse_instructions.h"
+#endif
+
+#ifdef ENABLE_ARM
+#include "nnacl_c/intrinsics/ms_simd_neon_instructions.h"
+#endif
+
+#define MS_SIMD_AVX512_INSTRUCTION(instruction, suffix) instruction##512##suffix
+#define MS_SIMD_AVX_INSTRUCTION(instruction, suffix) instruction##256##suffix
+#define MS_SIMD_SSE_INSTRUCTION(instruction, suffix) instruction##128##suffix
+#define MS_SIMD_NEON_INSTRUCTION(instruction, suffix) instruction##128##suffix
+
+#define MS_SIMD_INSTRUCTION_F32(instruction) MS_SIMD_INSTRUCTION(instruction, _F32)
+#define MS_SIMD_INSTRUCTION_EPI32(instruction) MS_SIMD_INSTRUCTION(instruction, _EPI32)
+#define MS_SIMD_INSTRUCTION_MASK(instruction) MS_SIMD_INSTRUCTION(instruction, _MASK)
+
+// define (float/int) data
+#define SIMD_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOAT)
+#define SIMD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_INT)
+#define SIMD_MASK MS_SIMD_INSTRUCTION(MS_MASK, _TYPE)
+
+// read scalar data
+#define SIMD_F32_GETI MS_SIMD_INSTRUCTION(MS, _F32_GETI)
+
+// move (float/int) data
+#define SIMD_MOV_F32 MS_SIMD_INSTRUCTION_F32(MS_MOV)
+#define SIMD_MOV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MOV)
+#define SIMD_SET0_F32 MS_SIMD_INSTRUCTION(MS_MOV, _VAL0_F32)
+
+// load (float/int) data
+#define SIMD_LD_F32 MS_SIMD_INSTRUCTION_F32(MS_LD)
+#define SIMD_LD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_LD)
+#define SIMD_LD_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_LD, _HALF_EPI32)
+
+// load 4 (float/int) data
+#define SIMD_LDX4_F32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_F32)
+#define SIMD_LDX4_EPI32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_EPI32)
+
+// store (float/int) data
+#define SIMD_ST_F32 MS_SIMD_INSTRUCTION_F32(MS_ST)
+#define SIMD_ST_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ST)
+#define SIMD_ST_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_ST, _HALF_EPI32)
+
+// sign
+#define SIMD_SIGN_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGN)
+#define SIMD_SIGNABS_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGNABS)
+
+// add (float/int) op
+#define SIMD_ADD_F32 MS_SIMD_INSTRUCTION_F32(MS_ADD)
+#define SIMD_ADD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ADD)
+#define SIMD_ADD_N_F32(val1, val2) MS_EXPAND(SIMD_ADD_F32(val1, SIMD_MOV_F32(val2)))
+#define SIMD_ADD_N_EPI32(val1, val2) MS_EXPAND(SIMD_ADD_EPI32(val1, SIMD_MOV_EPI32(val2)))
+
+// sub (float/int) op
+#define SIMD_SUB_F32 MS_SIMD_INSTRUCTION_F32(MS_SUB)
+#define SIMD_SUB_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_SUB)
+#define SIMD_SUB_N_F32(val1, val2) MS_EXPAND(SIMD_SUB_F32(val1, SIMD_MOV_F32(val2)))
+#define SIMD_SUB_N_EPI32(val1, val2) MS_EXPAND(SIMD_SUB_EPI32(val1, SIMD_MOV_EPI32(val2)))
+
+// div (float/int) op
+#define SIMD_DIV_F32 MS_SIMD_INSTRUCTION_F32(MS_DIV)
+#define SIMD_DIV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_DIV)
+#define SIMD_DIV_N_F32(val1, val2) MS_EXPAND(SIMD_DIV_F32(val1, SIMD_MOV_F32(val2)))
+#define SIMD_DIV_N_EPI32(val1, val2) MS_EXPAND(SIMD_DIV_EPI32(val1, SIMD_MOV_EPI32(val2)))
+
+// sqrt (float) op
+#define SIMD_SQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_SQRT)
+
+// rsqrt (float) op
+#define SIMD_RSQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_RSQRT)
+
+// log (float) op
+#define SIMD_LOG_F32 MS_SIMD_INSTRUCTION(MS, _LOG_F32)
+
+// cos (float) op
+#define SIMD_COS_F32 MS_SIMD_INSTRUCTION_F32(MS_COS)
+
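/* Editor's note: a hedged illustration, not part of the patch. The
   MS_SIMD_*_INSTRUCTION helpers paste the vector width into a generic name, so
   one kernel body can be stamped out per ISA. For instance, in a translation
   block where MS_SIMD_INSTRUCTION is bound to MS_SIMD_AVX_INSTRUCTION:

     SIMD_F32              // -> MS_FLOAT256_F32
     SIMD_LD_F32(ptr)      // -> MS_LD256_F32(ptr)
     SIMD_ADD_F32(a, b)    // -> MS_ADD256_F32(a, b)

   while the same source bound to MS_SIMD_SSE_INSTRUCTION yields the 128-bit
   variants. */
+// sin (float) op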
+#define SIMD_SIN_F32 MS_SIMD_INSTRUCTION_F32(MS_SIN) + +// erf (float) op +#define SIMD_ERF_F32 MS_SIMD_INSTRUCTION(MS, _ERF_F32) + +// abs (float) op +#define SIMD_ABS_F32 MS_SIMD_INSTRUCTION_F32(MS_ABS) +#define SIMD_ABS_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ABS) + +// round (float) op +#define SIMD_ROUND_F32 MS_SIMD_INSTRUCTION_F32(MS_ROUND) + +// ceil (float) op +#define SIMD_CEIL_F32 MS_SIMD_INSTRUCTION_F32(MS_CEIL) + +// floor (float) op +#define SIMD_FLOOR_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOOR) + +// tanh (float) op +#define SIMD_TANH_F32 MS_SIMD_INSTRUCTION_F32(MS_TANH) + +// min (float/int) op +#define SIMD_MIN_F32 MS_SIMD_INSTRUCTION_F32(MS_MIN) +#define SIMD_MIN_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MIN) +#define SIMD_MIN_N_F32(val1, val2) MS_EXPAND(SIMD_MIN_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_MIN_N_EPI32(val1, val2) MS_EXPAND(SIMD_MIN_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// max (float/int) op +#define SIMD_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_MAX) +#define SIMD_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MAX) +#define SIMD_MAX_N_F32(val1, val2) MS_EXPAND(SIMD_MAX_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_MAX_N_EPI32(val1, val2) MS_EXPAND(SIMD_MAX_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// get max (float/int) op +#define SIMD_GET_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_MAX) +#define SIMD_GET_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_GET_MAX) + +// get max (float/int) op +#define SIMD_GET_SUM_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_SUM) +#define SIMD_REDUCE_ADD_F32 MS_SIMD_INSTRUCTION(MS_REDUCE_ADD, _F32) + +// clamp (float/int) op +#define SIMD_CLAMP_F32(val, min_val, max_val) SIMD_MIN_F32(SIMD_MAX_F32(val, min_val), max_val) +#define SIMD_CLAMP_EPI32(val, min_val, max_val) SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, min_val), max_val) +#define SIMD_CLAMP_N_F32(val, min_val, max_val) \ + SIMD_MIN_F32(SIMD_MAX_F32(val, SIMD_MOV_F32(min_val)), SIMD_MOV_F32(max_val)) +#define SIMD_CLAMP_N_EPI32(val, min_val, max_val) \ + SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, SIMD_MOV_EPI32(min_val)), SIMD_MOV_EPI32(max_val)) + +// mul (float/int) op +#define SIMD_MUL_F32 MS_SIMD_INSTRUCTION_F32(MS_MUL) +#define SIMD_MUL_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MUL) +#define SIMD_MUL_N_F32(val1, val2) MS_EXPAND(SIMD_MUL_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_MUL_N_EPI32(val1, val2) MS_EXPAND(SIMD_MUL_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// pow (float) op +#define SIMD_POW_F32 MS_SIMD_INSTRUCTION_F32(MS_POW) + +// fma (float/int) op +#define SIMD_FMADD_F32 MS_SIMD_INSTRUCTION_F32(MS_FMADD) + +// fms (float/int) op +#define SIMD_FMSUB_F32 MS_SIMD_INSTRUCTION_F32(MS_FMSUB) + +// fsm (float) op +#define MS_FSMUL_F32 MS_SIMD_INSTRUCTION_F32(MS_FSMUL) + +// square (float/int) op +#define SIMD_MUL_SQUARE_F32(val1) SIMD_MUL_F32(val1, val1) +#define SIMD_MUL_SQUARE_EPI32(val1) SIMD_MUL_EPI32(val1, val1) + +// exp (float) op +#define SIMD_EXP_ST_F32 MS_SIMD_INSTRUCTION(simd_exp, ) +#define SIMD_EXP_F32 MS_SIMD_INSTRUCTION(simd_exp, _f32) +// exp (float) high precision but a little slow op. 
+#define SIMD_HEXP_F32 MS_SIMD_INSTRUCTION(simd_hexp, _f32) + +// cmp (float/int) op +#define SIMD_CMPLT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLT) +#define SIMD_CMPLE_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLE) +#define SIMD_CMPGT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPGT) +#define SIMD_BLEND_F32 MS_SIMD_INSTRUCTION_F32(MS_BLEND) + +// cast data +#define MS_CAST_F32_S32 MS_SIMD_INSTRUCTION(MS_CAST, _F32_S32) + +// logical op +#define SIMD_AND_MASK MS_SIMD_INSTRUCTION_MASK(MS_AND) +#define SIMD_OR_F32 MS_SIMD_INSTRUCTION_F32(MS_OR) +#define SIMD_AND_MASK_F32 MS_SIMD_INSTRUCTION(MS_AND, _MASK_F32) +#define SIMD_AND_F32 MS_SIMD_INSTRUCTION_F32(MS_AND) + +#define SIMD_GETSIGN_F32(src) \ + SIMD_OR_F32(SIMD_AND_F32(src, MS_CAST_F32_S32(SIMD_MOV_EPI32(0x80000000))), \ + MS_CAST_F32_S32(SIMD_MOV_EPI32(0x3F800000))) + +// int32/float mutual conversion +#define SIMD_EPI32_TO_F32 MS_SIMD_INSTRUCTION(MS, _INT32_TO_FLOAT32) +#define SIMD_F32_TO_EPI32 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_INT32) +#define SIMD_F16_TO_F32 MS_SIMD_INSTRUCTION(MS, _FLOAT16_TO_FLOAT32) +#define SIMD_F32_TO_F16 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_FLOAT16) + +// enable avx512 +#if defined(ENABLE_AVX512) +#define SIMD_RUN_AVX512(function, index, ...) \ + do { \ + AVX512_HARDWARE_SELF_AWARENESS_BEGIN \ + index = function##AVX512(index, __VA_ARGS__); \ + AVX512_HARDWARE_SELF_AWARENESS_END \ + } while (0) +#else +#define SIMD_RUN_AVX512(function, index, ...) +#endif + +// enable avx256 +#if defined(ENABLE_AVX) +#define SIMD_RUN_AVX(function, index, ...) index = function##AVX(index, __VA_ARGS__) +#else +#define SIMD_RUN_AVX(function, index, ...) +#endif + +// enable sse +#if defined(ENABLE_SSE) +#define SIMD_RUN_SSE(function, index, ...) index = function##SSE(index, __VA_ARGS__) +#else +#define SIMD_RUN_SSE(function, index, ...) +#endif + +// enable neon +#if defined(ENABLE_NEON) +#define SIMD_RUN_NEON(function, index, ...) index = function##NEON(index, __VA_ARGS__) +#else +#define SIMD_RUN_NEON(function, index, ...) +#endif + +#define SIMD_RUN_NO_SCALAR(function, index, ...) \ + do { \ + SIMD_RUN_AVX512(function, index, __VA_ARGS__); \ + SIMD_RUN_AVX(function, index, __VA_ARGS__); \ + SIMD_RUN_SSE(function, index, __VA_ARGS__); \ + SIMD_RUN_NEON(function, index, __VA_ARGS__); \ + } while (0) + +#define SIMD_RUN_X86_NO_SCALAR(function, index, ...) \ + do { \ + SIMD_RUN_AVX512(function, index, __VA_ARGS__); \ + SIMD_RUN_AVX(function, index, __VA_ARGS__); \ + SIMD_RUN_SSE(function, index, __VA_ARGS__); \ + } while (0) + +#define SIMD512_BLOCK16 32 // SIMD : 512 = 16 x 32 +#define SIMD256_BLOCK16 16 // SIMD : 256 = 16 x 16 +#define SIMD128_BLOCK16 8 // SIMD : 128 = 16 x 8 + +#define SIMD512_BLOCK32 16 // SIMD : 512 = 32 x 16 +#define SIMD256_BLOCK32 8 // SIMD : 256 = 32 x 8 +#define SIMD128_BLOCK32 4 // SIMD : 128 = 32 x 4 + +#define SIMD512_BLOCK64 8 // SIMD : 512 = 64 x 8 +#define SIMD256_BLOCK64 4 // SIMD : 256 = 64 x 4 +#define SIMD128_BLOCK64 2 // SIMD : 128 = 64 x 2 + +#define MS_EXPAND(...) 
__VA_ARGS__ + +// Scaler +#define MS_FLOAT32X1 float +#define MS_INT32X1 int +#define MS_MOV32_F32(value) (value) +#define MS_MOV32_EPI32(value) (value) +#define MS_LD32_F32(address) (*(address)) +#define MS_LD32_EPI32(address) (*(address)) +#define MS_ST32_F32(address, value) (*(address) = (value)) +#define MS_ST32_EPI32(address, value) (*(address) = (value)) +#define MS_ADD32_F32(value1, value2) ((value1) + (value2)) +#define MS_ADD32_EPI32(value1, value2) ((value1) + (value2)) +#define MS_SUB32_F32(value1, value2) ((value1) - (value2)) +#define MS_SUB32_EPI32(value1, value2) ((value1) - (value2)) +#define MS_MUL32_F32(value1, value2) ((value1) * (value2)) +#define MS_MUL32_EPI32(value1, value2) ((value1) * (value2)) +#define MS_DIV32_F32(value1, value2) ((value1) / (value2)) +#define MS_DIV32_EPI32(value1, value2) ((value1) / (value2)) +#define MS_MIN32_F32(value1, value2) (fmin((value1), (value2))) +#define MS_MIN32_EPI32(value1, value2) ((value1) < (value2) ? (value1) : (value2)) +#define MS_MAX32_F32(value1, value2) (fmax((value1), (value2))) +#define MS_MAX32_EPI32(value1, value2) ((value1) > (value2) ? (value1) : (value2)) +#define MS_SQRT32_F32(value) (sqrt(value)) + +static inline float simd_exp32_f32(float data) { + typedef union { + float f; + int i; + } fi; + static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; // Approximate calculation param +#ifdef _WIN32 + if (data < -88.0f) { + return 0.0f; + } else if (data > 88.0f) { + return 1.6516363e+38; // e^88 = 1.6516363e+38 + } +#else + data = + MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, data)); // clamp(logf(FLT_MIN), logf(FLT_MAX)) +#endif + int integer = floor(data * 1.44269504088896341f + 0.5f); + float decimal = data - integer * param[0]; + fi int_exp; + const int shift = 23; + const int bias = 126; + const float factor = 2; + // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), + // because n may be 128, and it is not representable by fp32. + int_exp.i = (integer + bias) << shift; // integer num 2^(n - 1) approximate calculation : ((x - 1) + 127) << 23 + // Approximate calculation + const float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + return factor * int_exp.f * decimal_exp; +} + +// exp(x) = exp(n * ln(2) + r) = 2^n * exp(r) = 2 * 2^(n - 1) * exp(r) +static inline void simd_exp32(float src, float *dst) { + typedef union { + float f; + int i; + } fi; + static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; // log(2.0f) + src = MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, src)); // clamp(logf(FLT_MIN), logf(FLT_MAX)) + int integer = floor(src * 1.44269504088896341f + 0.5f); + float decimal = src - integer * param[0]; + fi int_exp; + const int shift = 23; + const int bias = 126; + const float factor = 2; + // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), + // because n may be 128, and it is not representable by fp32. + int_exp.i = (integer + bias) << shift; // integer num 2^(n - 1) approximate calculation : ((x - 1) + 127) << 23 + const float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + *dst = factor * int_exp.f * decimal_exp; +} + +// define (float/int) data +#define MS_FLOAT_32xN(byte_num) MS_FLOAT32##X##byte_num +#define MS_INT_32xN(byte_num) MS_INT32##X##byte_num + +// move (float/int) data +#define MS_MOVN_F32(byte_num, ...) 
MS_EXPAND(MS_MOV##byte_num##_F32(__VA_ARGS__)) +#define MS_MOVN_EPI32(byte_num, ...) MS_EXPAND(MS_MOV##byte_num##_EPI32(__VA_ARGS__)) + +// load (float/int) data +#define MS_LD_F32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_F32(__VA_ARGS__)) +#define MS_LD_EPI32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_EPI32(__VA_ARGS__)) + +// load 4 (float/int) data +#define MS_LDX4_F32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_F32(__VA_ARGS__)) +#define MS_LDX4_EPI32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_EPI32(__VA_ARGS__)) + +// stored (float/int) data +#define MS_ST_F32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_F32(__VA_ARGS__)) +#define MS_ST_EPI32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_EPI32(__VA_ARGS__)) + +// add (float/int) op +#define MS_ADD_F32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_F32(__VA_ARGS__)) +#define MS_ADD_EPI32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_EPI32(__VA_ARGS__)) +#define MS_ADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_ADD_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_EPI32(val1, MS_MOV##bit_num##_F32(val2))) + +// sub (float/int) op +#define MS_SUB_F32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_F32(__VA_ARGS__)) +#define MS_SUB_EPI32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_EPI32(__VA_ARGS__)) +#define MS_SUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_SUB_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_EPI32(val1, MS_MOV##bit_num##_F32(val2))) + +// div (float/int) op +#define MS_DIV_F32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_F32(__VA_ARGS__)) +#define MS_DIV_EPI32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_EPI32(__VA_ARGS__)) +#define MS_DIV_N_F32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_DIV_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2))) + +// sqrt (float) op +#define MS_SQRT_F32(bit_num, ...) MS_EXPAND(MS_SQRT##bit_num##_F32(__VA_ARGS__)) + +// rsqrt (float) op +#define MS_RSQRT_F32(bit_num, ...) MS_EXPAND(MS_RSQRT##bit_num##_F32(__VA_ARGS__)) + +// log (float) op +#define MS_LOG_F32(bit_num, ...) MS_EXPAND(MS_LOG##bit_num##_F32(__VA_ARGS__)) + +// cos (float) op +#define MS_COS_F32(bit_num, ...) MS_EXPAND(MS_COS##bit_num##_F32(__VA_ARGS__)) + +// sin (float) op +#define MS_SIN_F32(bit_num, ...) MS_EXPAND(MS_SIN##bit_num##_F32(__VA_ARGS__)) + +// erf (float) op +#define MS_ERF_F32(bit_num, ...) MS_EXPAND(MS_ERF##bit_num##_F32(__VA_ARGS__)) + +// log (float) op +#define MS_ABS_F32(bit_num, ...) MS_EXPAND(MS_ABS##bit_num##_F32(__VA_ARGS__)) + +// round (float) op +#define MS_ROUND_F32(bit_num, ...) MS_EXPAND(MS_ROUND##bit_num##_F32(__VA_ARGS__)) + +// ceil (float) op +#define MS_CEIL_F32(bit_num, ...) MS_EXPAND(MS_CEIL##bit_num##_F32(__VA_ARGS__)) + +// floor (float) op +#define MS_FLOOR_F32(bit_num, ...) MS_EXPAND(MS_FLOOR##bit_num##_F32(__VA_ARGS__)) + +// min (float/int) op +#define MS_MIN_F32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_F32(__VA_ARGS__)) +#define MS_MIN_EPI32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_EPI32(__VA_ARGS__)) +#define MS_MIN_N_F32(bit_num, val, n) MS_MIN_F32(bit_num, val, MS_MOVN_F32(bit_num, n)) +#define MS_MIN_N_EPI32(bit_num, val, n) MS_MIN_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n)) + +// max (float/int) op +#define MS_MAX_F32(bit_num, ...) MS_EXPAND(MS_MAX##bit_num##_F32(__VA_ARGS__)) +#define MS_MAX_EPI32(bit_num, ...) 
MS_EXPAND(MS_MAX##bit_num##_EPI32(__VA_ARGS__)) + +// get max (float/int) op +#define MS_GET_MAX_F32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_F32(__VA_ARGS__)) +#define MS_GET_MAX_EPI32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_EPI32(__VA_ARGS__)) + +// get max (float/int) op +#define MS_GET_SUM_F32(bit_num, ...) MS_EXPAND(MS_GET_SUM##bit_num##_F32(__VA_ARGS__)) + +// max n (float/int) op +#define MS_MAX_N_F32(bit_num, val, n) MS_MAX_F32(bit_num, val, MS_MOVN_F32(bit_num, n)) +#define MS_MAX_N_EPI32(bit_num, val, n) MS_MAX_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n)) +#define MS_CLAMP_F32(bit_num, val, min_val, max_val) MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, min_val), max_val) +#define MS_CLAMP_EPI32(bit_num, val, min_val, max_val) \ + MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, min_val), max_val) + +// clamp n (float/int) op +#define MS_CLAMP_N_F32(bit_num, val, min_val, max_val) \ + MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, MS_MOV##bit_num##_F32(min_val)), MS_MOV##bit_num##_F32(max_val)) +#define MS_CLAMP_N_EPI32(bit_num, val, min_val, max_val) \ + MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, MS_MOV##bit_num##_EPI32(min_val)), MS_MOV##bit_num##_EPI32(max_val)) + +// mul (float/int) op +#define MS_MUL_F32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_F32(__VA_ARGS__)) +#define MS_MUL_EPI32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_EPI32(__VA_ARGS__)) +#define MS_MUL_N_F32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_MUL_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2))) + +// fma (float/int) op +#define MS_FMADD_F32(bit_num, ...) MS_EXPAND(MS_FMADD##bit_num##_F32(__VA_ARGS__)) +#define MS_FMADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) + +// fms (float/int) op +#define MS_FMSUB_F32(bit_num, ...) MS_EXPAND(MS_FMSUB##bit_num##_F32(__VA_ARGS__)) +#define MS_FMSUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMSUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) + +// square (float/int) op +#define MS_MUL_SQUARE_F32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_F32(val, val))) +#define MS_MUL_SQUARE_EPI32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_EPI32(val, val))) + +// exp (float) op +#define MS_EXP_ST_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num(__VA_ARGS__))) +#define MS_EXP_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num##_f32(__VA_ARGS__))) + +#define MS_CMPLT_F32(bit_num, ...) MS_EXPAND((MS_CMPLT##bit_num##_F32(__VA_ARGS__))) +#define MS_CMPLE_F32(bit_num, ...) MS_EXPAND((MS_CMPLE##bit_num##_F32(__VA_ARGS__))) +#define MS_CMPGT_F32(bit_num, ...) MS_EXPAND((MS_CMPGT##bit_num##_F32(__VA_ARGS__))) +#define MS_BLEND_F32(bit_num, ...) MS_EXPAND((MS_BLEND##bit_num##_F32(__VA_ARGS__))) + +#define MS_INT16_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT16_TO_FLOAT16(__VA_ARGS__))) +#define MS_FLOAT16_TO_INT16(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT16(__VA_ARGS__))) + +#define MS_INT32_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT16(__VA_ARGS__))) +#define MS_FLOAT16_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT32(__VA_ARGS__))) + +#define MS_INT32_TO_FLOAT32(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT32(__VA_ARGS__))) +#define MS_FLOAT32_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT32(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT32(bit_num, ...) 
MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT32(__VA_ARGS__))) +#define MS_FLOAT32_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT64(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT16(__VA_ARGS__))) +#define MS_FLOAT16_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT64(__VA_ARGS__))) + +#define MS_INT32_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT64(__VA_ARGS__))) +#define MS_FLOAT64_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT32(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT64(__VA_ARGS__))) +#define MS_FLOAT64_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT64(__VA_ARGS__))) + +// enable avx512 +#if defined(ENABLE_AVX512) +#define MS_SIMD_RUN_AVX512(function, ...) MS_EXPAND(function(512, 16, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_AVX512(function, ...) +#endif + +// enable avx256 +#if defined(ENABLE_AVX) +#define MS_SIMD_RUN_AVX(function, ...) MS_EXPAND(function(256, 8, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_AVX(function, ...) +#endif + +// enable sse +#if defined(ENABLE_SSE) +#define MS_SIMD_RUN_SSE(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_SSE(function, ...) +#endif + +// enable neon +#if defined(ENABLE_NEON) +#define MS_SIMD_RUN_NEON(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_NEON(function, ...) +#endif + +// enable neon/sse +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) +#define MS_SIMD_RUN_SSEORNEON128(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_SSEORNEON128(function, ...) +#endif + +// scalar (c style data) +#define MS_SIMD_RUN_SCALAR(function, ...) MS_EXPAND(function(32, 1, __VA_ARGS__)) + +#define MS_SIMD_RUN(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__); \ + MS_SIMD_RUN_SCALAR(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_NO_SCALAR(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_X86(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSE(function, __VA_ARGS__); \ + MS_SIMD_RUN_SCALAR(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_X86_NO_SCALAR(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSE(function, __VA_ARGS__); \ + } while (0) + +#endif // NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h new file mode 100644 index 00000000..f5536771 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h @@ -0,0 +1,162 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
+#define NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
+#include <arm_neon.h>
+#include "nnacl_c/intrinsics/ms_simd_instructions.h"
+
+#if defined(ENABLE_ARM82_A32)
+static inline float16x8_t ms_vdivq_f16(float16x8_t in1, float16x8_t in2) {
+  float16x8_t dst;
+  asm volatile(
+    "vrecpe.f16 q14, %3\n"
+    "vrecps.f16 q15, %3, q14\n"
+    "vmul.f16 q14, q15, q14\n"
+    "vrecps.f16 q15, %3, q14\n"
+    "vmul.f16 q14, q15, q14\n"
+    "vmul.f16 %0, %2, q14\n"
+    : "=w"(dst)
+    : "0"(dst), "w"(in1), "w"(in2)
+    : "q14", "q15");
+  return dst;
+}
+
+static inline float16x4_t ms_vdiv_f16(float16x4_t in1, float16x4_t in2) {
+  float16x4_t dst;
+  asm volatile(
+    "vrecpe.f16 d14, %3\n"
+    "vrecps.f16 d16, %3, d14\n"
+    "vmul.f16 d14, d16, d14\n"
+    "vrecps.f16 d16, %3, d14\n"
+    "vmul.f16 d14, d16, d14\n"
+    "vmul.f16 %0, %2, d14\n"
+    : "=w"(dst)
+    : "0"(dst), "w"(in1), "w"(in2)
+    : "d14", "d16");
+  return dst;
+}
+
+static inline float ms_vaddvq_f32(float32x4_t in) {
+  // not supported in arm82 aarch32; there is no single assembly instruction to reduce all the lanes
+  return in[0] + in[1] + in[2] + in[3];
+}
+
+static inline float16_t ms_vmaxvq_f16(float16x8_t in) {
+  // not supported in arm82 aarch32; there is no single assembly instruction to reduce all the lanes
+  float16_t dst = in[0];
+  for (int i = 1; i < 8; ++i) {
+    dst = dst > in[i] ?
dst : in[i]; + } + return dst; +} + +static inline float32x4_t ms_vcvt_f32_f16(float16x4_t in) { + float32x4_t dst; + asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) { + float16x4_t dst; + asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +#define MS_CVT_F32_F16(src) ms_vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) ms_vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) ms_vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) ms_vdivq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3)) +#define MS_MAXVQ_F16(src) ms_vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) ms_vaddvq_f32(src) +#else +#define MS_CVT_F32_F16(src) vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3) +#define MS_MAXVQ_F16(src) vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) vaddvq_f32(src) +#endif + +#define MS_FLOAT16X8 float16x8_t +#define MS_FLOAT16X4 float16x4_t +#define MS_FLOAT16X4X4 float16x4x4_t +#define MS_FLOAT16X4X2 float16x4x2_t +#define MS_MOVQ_F16 vmovq_n_f16 +#define MS_STQ_F16(ptr, val) vst1q_f16(ptr, val) +#define MS_ST_F16 vst1_f16 +#define MS_ST2_F16 vst2_f16 +#define MS_ST4_F16 vst4_f16 +#define MS_MINQ_F16 vminq_f16 +#define MS_MAXQ_F16 vmaxq_f16 +#define MS_LDQ_F16(ptr) vld1q_f16(ptr) +#define MS_LD_F16(ptr) vld1_f16(ptr) +#define MS_ADDQ_F16 vaddq_f16 +#define MS_SUBQ_F16 vsubq_f16 +#define MS_MULQ_F16 vmulq_f16 +#define MS_FMAQ_F16 vfmaq_f16 +#define MS_MULQ_N_F16(vector, scalar) vmulq_n_f16(vector, scalar) +#define MS_CMPGTQ_F16(src1, src2) vcgtq_f16(src1, src2) + +static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { + float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src)); + float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src)); + return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high))); +} + +static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = erff(src[0]); + dst[1] = erff(src[1]); + dst[2] = erff(src[2]); + dst[3] = erff(src[3]); + dst[4] = erff(src[4]); + dst[5] = erff(src[5]); + dst[6] = erff(src[6]); + dst[7] = erff(src[7]); + return dst; +} + +static inline float16x8_t MS_SQRTFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = sqrtf(src[0]); + dst[1] = sqrtf(src[1]); + dst[2] = sqrtf(src[2]); + dst[3] = sqrtf(src[3]); + dst[4] = sqrtf(src[4]); + dst[5] = sqrtf(src[5]); + dst[6] = sqrtf(src[6]); + dst[7] = sqrtf(src[7]); + return dst; +} + +static inline float16x4_t MS_SQRTFX4_F16(float16x4_t src) { + float16x4_t dst; + dst[0] = sqrtf(src[0]); + dst[1] = sqrtf(src[1]); + dst[2] = sqrtf(src[2]); + dst[3] = sqrtf(src[3]); + return dst; +} + +static inline float32x4_t MS_VMLAL_F16(float16x4_t x, float16x4_t dy, float32x4_t sum) { + float32x4_t x_fp32 = MS_CVT_F32_F16(x); + float32x4_t dy_fp32 = MS_CVT_F32_F16(dy); + return vmlaq_f32(sum, x_fp32, dy_fp32); +} + +#endif // NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_neon_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_neon_instructions.h new file mode 100644 index 00000000..53333c7f --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_neon_instructions.h
@@ -0,0 +1,362 @@
+/**
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#define NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#include <arm_neon.h>
+#include <math.h>
+
+#include <float.h>
+
+#define MS_F32X4_GETI(src, i) src[i]
+#define MS128_F32_GETI(src, i) src[i]
+#define MS_FLOAT32X4 float32x4_t
+#define MS_FLOAT32X4X2 float32x4x2_t
+#define MS_FLOAT32X4X4 float32x4x4_t
+#define MS_FLOAT128_F32 float32x4_t
+#define MS_INT32X4 int32x4_t
+#define MS_INT128_EPI32 int32x4_t
+#define MS_UINT32X4 uint32x4_t
+#define MS_MASK128_TYPE MS_UINT32X4
+#define MS_LDQ_F32 vld1q_f32
+#define MS_LD128_F32 vld1q_f32
+#define MS_LDQ_EPI32 vld1q_s32
+#define MS_LD128_EPI32 vld1q_s32
+#define MS_ADDQ_F32 vaddq_f32
+#define MS_ADD128_F32 vaddq_f32
+#define MS_ADDQ_EPI32 vaddq_s32
+#define MS_ADD128_EPI32 vaddq_s32
+#define MS_MOVQ_F32 vmovq_n_f32
+#define MS_MOV128_F32 vmovq_n_f32
+#define MS_MOVQ_EPI32 vmovq_n_s32
+#define MS_MOV128_VAL0_F32 vmovq_n_f32(0.0f)
+#define MS_MOV128_EPI32 vmovq_n_s32
+#define MS_SUBQ_F32 vsubq_f32
+#define MS_SUB128_F32 vsubq_f32
+#define MS_SUB128_EPI32 vsubq_s32
+#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
+#define MS_STQ_F32 vst1q_f32
+#define MS_ST128_F32 vst1q_f32
+#define MS_STQ_EPI32 vst1q_s32
+#define MS_ST128_EPI32 vst1q_s32
+#define MS_MAXQ_F32 vmaxq_f32
+#define MS_MAXQ_EPI32 vmaxq_s32
+#define MS_MAX128_F32 vmaxq_f32
+#define MS_MAX128_EPI32 vmaxq_s32
+#define MS_MINQ_F32 vminq_f32
+#define MS_MINQ_EPI32 vminq_s32
+#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
+#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
+#define MS_MIN128_F32 vminq_f32
+#define MS_MIN128_EPI32 vminq_s32
+#define MS_MUL128_F32(src1, src2) vmulq_f32(src1, src2)
+#define MS_MUL128_EPI32(src1, src2) vmulq_s32(src1, src2)
+#define MS_FMADD128_F32(src1, src2, src3) vmlaq_f32(src3, src1, src2)
+#define MS_FSMUL128_F32(src1, src2, src3) vmlsq_f32(src1, src2, src3)
+#define MS_FMSUB128_EPI32(src1, src2, src3) vmlsq_s32(src3, src1, src2)
+#ifdef ENABLE_ARM64
+#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
+#define MS_DIV128_F32(src1, src2) vdivq_f32(src1, src2)
+#else
+static inline float32x4_t vrecp(float32x4_t v) {
+  float32x4_t r = vrecpeq_f32(v);
+  r = vmulq_f32(vrecpsq_f32(v, r), r);
+  r = vmulq_f32(vrecpsq_f32(v, r), r);
+  return r;
+}
+#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
+#define MS_DIV128_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
+#endif
+#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
+#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
+#define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2)
+#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
+#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
+#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
+#define MS_CMPLEQ_F32(src1, src2) vcleq_f32(src1, src2)
+#define
MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2) +#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2) +#define MS_CMPLE128_F32(src1, src2) vcleq_f32(src1, src2) +#define MS_CMPLT128_F32(src1, src2) vcltq_f32(src1, src2) +#define MS_CMPGT128_F32(src1, src2) vcgtq_f32(src1, src2) +#define MS_CMPGT128_EPI32(src1, src2) vcgtq_s32(src1, src2) +// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_f32 +#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1) +#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1) +#define MS_BLEND128_F32(src1, src2, src3) vbslq_f32(src3, src2, src1) +#define MS_BLEND128_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1) +#define MS_CAST128_F32_S32(src) vreinterpretq_f32_s32(src) +#define MS_AND128_MASK(src1, src2) vandq_u32(src1, src2) +#define MS_AND128_F32(src1, src2) \ + vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(src1), vreinterpretq_u32_f32(src2))) +#define MS_OR128_F32(src1, src2) \ + vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(src1), vreinterpretq_u32_f32(src2))) +#define MS_CAST128_U32_F32(src) vreinterpretq_u32_f32(src) +#define MS_CAST128_F32_U32(src) vreinterpretq_f32_u32(src) +#define MS_OR128_MASK(src1, src2) vorrq_u32(src1, src2) + +#ifdef ENABLE_ARM64 +#define MS_GET_MAX128_F32 vmaxvq_f32 +static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) { return vaddvq_f32(src); } +#else +static inline float MS_GET_MAX128_F32(MS_FLOAT32X4 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // neon block num : 4 + result = fmaxf(result, MS_F32X4_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // neon block num : 4 + result = result + MS_F32X4_GETI(src, i); + } + return result; +} +#endif + +static inline MS_FLOAT32X4 MS_AND128_MASK_F32(MS_UINT32X4 src1, MS_FLOAT32X4 src2) { + MS_FLOAT32X4 result; + result[0] = (src1[0] == 0) ? 0.0f : src2[0]; + result[1] = (src1[1] == 0) ? 0.0f : src2[1]; + result[2] = (src1[2] == 0) ? 0.0f : src2[2]; + result[3] = (src1[3] == 0) ? 
0.0f : src2[3]; + return result; +} + +static inline int32x4_t MS_DIV128_EPI32(int32x4_t src1, int32x4_t src2) { + int32x4_t result; + result[0] = src1[0] / src2[0]; // C0 : 0 + result[1] = src1[1] / src2[1]; // C1 : 1 + result[2] = src1[2] / src2[2]; // C2 : 2 + result[3] = src1[3] / src2[3]; // C3 : 3 + return result; +} + +#define MS128_INT32_TO_FLOAT32(src) vcvtq_f32_s32(src) +#define MS128_FLOAT32_TO_INT32(src) vcvtq_s32_f32(src) + +static inline MS_FLOAT32X4 MS_POW128_F32(MS_FLOAT32X4 src1, MS_FLOAT32X4 src2) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = powf(MS_F32X4_GETI(src1, 0), MS_F32X4_GETI(src2, 0)); + MS_F32X4_GETI(dst, 1) = powf(MS_F32X4_GETI(src1, 1), MS_F32X4_GETI(src2, 1)); + MS_F32X4_GETI(dst, 2) = powf(MS_F32X4_GETI(src1, 2), MS_F32X4_GETI(src2, 2)); + MS_F32X4_GETI(dst, 3) = powf(MS_F32X4_GETI(src1, 3), MS_F32X4_GETI(src2, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_ABS128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = fabsf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = fabsf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = fabsf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = fabsf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS128_LOG_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = logf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = logf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = logf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = logf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_SQRT128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} +#define MS_RSQRT128_F32 vrsqrteq_f32 + +#define LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define STORE128X8_F32(output_ptr, num, dst) \ + MS_STQ_F32(output_ptr + 0 * num, dst##1); \ + MS_STQ_F32(output_ptr + 1 * num, dst##2); \ + MS_STQ_F32(output_ptr + 2 * num, dst##3); \ + MS_STQ_F32(output_ptr + 3 * num, dst##4); \ + MS_STQ_F32(output_ptr + 4 * num, dst##5); \ + MS_STQ_F32(output_ptr + 5 * num, dst##6); \ + MS_STQ_F32(output_ptr + 6 * num, dst##7); \ + MS_STQ_F32(output_ptr + 7 * num, dst##8); + +static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 
1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + static MS_FLOAT32X4 negative_flag = {-0.0f, -0.0f, -0.0f, -0.0f}; + + MS_INT32X4 integer = + MS_CVTQPS_EPI32(MS_FMADD128_F32(input, param[6], MS_OR128_F32(MS_AND128_F32(input, negative_flag), param[4]))); + MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); + MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); + MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); + tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); + MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); + return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); +} + +static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + MS_STQ_F32(dst, VexpFp32(input)); +} + +static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + return VexpFp32(input); +} + +static inline MS_FLOAT32X4 simd_hexp128_f32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = exp(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = exp(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = exp(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = exp(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X4 up_limit = {5.0f, 5.0f, 5.0f, 5.0f}; + static const MS_FLOAT32X4 down_limit = {-5.0f, -5.0f, -5.0f, -5.0f}; + + MS_UINT32X4 up_mask = MS_CMPGTQ_F32(src, up_limit); + MS_UINT32X4 down_mask = MS_CMPGTQ_F32(down_limit, src); + + MS_FLOAT32X4 square = MS_MULQ_F32(src, src); + MS_FLOAT32X4 a = MS_MULQ_F32( + MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(square, data0), square), data1), square), data2), src); + MS_FLOAT32X4 b = MS_ADDQ_F32( + MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4), square), data5), square), + data2); + + MS_FLOAT32X4 tanh_value = MS_DIVQ_F32(a, b); + MS_FLOAT32X4 res = MS_BLENDQ_F32(tanh_value, pos, up_mask); + res = MS_BLENDQ_F32(res, neg, down_mask); + return res; +} + +static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { + MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); + MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); + MS_FLOAT128_F32 
sign = MS_DIV128_F32(abs_src, src_tmp); + return sign; +} + +static inline MS_FLOAT128_F32 SIMD_SIGNABS128_F32(MS_FLOAT128_F32 src, MS_FLOAT128_F32 abs_src) { + MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); + return MS_DIV128_F32(abs_src, src_tmp); +} + +#define MS_TANH128_F32 MS_TANHX4_F32 + +static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3)); + return dst; +} + +#define MS_FMADD128X8_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \ + dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \ + dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \ + dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \ + dst##8 = MS_MLAQ_F32(src##8, weight, dst##8); + +#define MS_LOAD128X4_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); + +#define MS_FMADD128X4_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); + +#define MS_LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define MS_SET_ZERO128X8_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f); + +#define MS_SET_ZERO128X4_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); +#endif // NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_sse_instructions.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_sse_instructions.h new file mode 100644 index 00000000..6eb07e25 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_sse_instructions.h @@ -0,0 +1,403 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include <math.h> + +#ifdef _MSC_VER +#include <immintrin.h> +#define MS_F32X4_GETI(src, i) src.m128_f32[i] +#define MS128_F32_GETI(src, i) src.m128_f32[i] +#else +#include <x86intrin.h> +#define MS_F32X4_GETI(src, i) src[i] +#define MS128_F32_GETI(src, i) src[i] +#endif + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X4 __m128 +#define MS_FLOAT128_F32 __m128 +#define MS_INT32X4 __m128i +#define MS_INT128_EPI32 __m128i +#define MS_MASK128_TYPE MS_FLOAT32X4 +#define MS_LDQ_F32 _mm_loadu_ps +#define MS_LD128_F32 _mm_loadu_ps +#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src)) +#define MS_LD128_EPI32(src) _mm_loadu_si128((__m128i const *)(src)) +#define MS_ADDQ_F32 _mm_add_ps +#define MS_ADD128_F32 _mm_add_ps +#define MS_ADDQ_EPI32 _mm_add_epi32 +#define MS_ADD128_EPI32 _mm_add_epi32 +#define MS_MOVQ_F32 _mm_set1_ps +#define MS_MOV128_F32 _mm_set1_ps +#define MS_MOVQ_EPI32 _mm_set1_epi32 +#define MS_MOV128_EPI32 _mm_set1_epi32 +#define MS_MOV128_VAL0_F32 _mm_setzero_ps() +#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3)) +#define MS_STQ_F32 _mm_storeu_ps +#define MS_ST128_F32 _mm_storeu_ps +#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2) +#define MS_ST128_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2) +#define MS_SUBQ_F32 _mm_sub_ps +#define MS_SUB128_F32 _mm_sub_ps +#define MS_SUB128_EPI32 _mm_sub_epi32 +#define MS_MAXQ_F32 _mm_max_ps +#define MS_MAXQ_EPI32 _mm_max_epi32 +#define MS_MAX128_F32 _mm_max_ps +#define MS_MAX128_EPI32 _mm_max_epi32 +#define MS_MINQ_F32 _mm_min_ps +#define MS_MINQ_EPI32 _mm_min_epi32 +#define MS_SQRT128_F32 _mm_sqrt_ps +#define MS_RSQRT128_F32 _mm_rsqrt_ps +#define MS_SIN128_F32 _mm_sin_ps +#define MS_ERF128_F32 _mm_erf_ps +#define MS_ROUND128_F32(src) _mm_round_ps(src, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR128_F32 _mm_floor_ps +#define MS_CEIL128_F32 _mm_ceil_ps +#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2) +#define MS_MULQ_EPI32(src1, src2) _mm_mullo_epi32(src1, src2) +#define MS_MIN128_F32 _mm_min_ps +#define MS_MIN128_EPI32 _mm_min_epi32 +#define MS_MUL128_F32(src1, src2) _mm_mul_ps(src1, src2) +#define MS_MUL128_EPI32(src1, src2) _mm_mullo_epi32(src1, src2) +#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2) +#define MS_DIV128_F32(src1, src2) _mm_div_ps(src1, src2) +#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2)) +#define MS_MULQ_N_EPI32(src1, src2) _mm_mullo_epi32(src1, _mm_set1_epi32(src2)) +#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2)) +#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2) +#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int +#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src) +#define MS_CMPLEQ_F32(src1, src2) _mm_cmple_ps(src1, src2) +#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2) +#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2) +#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
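/* [Editor's note: illustrative sketch, not part of the original patch; ms_blend_demo is a hypothetical helper added only for explanation.]
 * Both the SSE and NEON headers define MS_BLENDQ_F32(src1, src2, mask) to select lanes of src2 where the mask lane is
 * set and lanes of src1 elsewhere; only the raw intrinsic's operand order differs: _mm_blendv_ps takes the mask last
 * (selecting by each lane's sign bit), while vbslq_f32 takes it first (selecting bitwise). */
static inline MS_FLOAT32X4 ms_blend_demo(void) {
  MS_FLOAT32X4 a = _mm_setr_ps(1.0f, 2.0f, 3.0f, 4.0f);
  MS_FLOAT32X4 b = _mm_set1_ps(9.0f);
  MS_FLOAT32X4 mask = _mm_cmpgt_ps(a, _mm_set1_ps(2.0f));  // lanes: {0, 0, ~0, ~0}
  return MS_BLENDQ_F32(a, b, mask);                        // yields {1, 2, 9, 9} on either backend
}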
+#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3) +#define MS_CMPLT128_F32(src1, src2) _mm_cmplt_ps(src1, src2) +#define MS_CMPLE128_F32(src1, src2) _mm_cmple_ps(src1, src2) +#define MS_CMPGT128_F32(src1, src2) _mm_cmpgt_ps(src1, src2) +#define MS_CMPEQ128_F32(src1, src2) _mm_cmpeq_ps(src1, src2) +#define MS_CMPUNORD128_F32(src1, src2) _mm_cmpunord_ps(src1, src2) +#define MS_CMPGT128_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2) +#define MS_BLEND128_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3) +#define MS_BLEND128_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3) +#define MS_CAST128_F32_S32(src) _mm_castsi128_ps(src) +#define MS_DIV128_EPI32(src1, src2) _mm_cvttps_epi32(MS_DIV128_F32(_mm_cvtepi32_ps(src1), _mm_cvtepi32_ps(src2))) +#define MS_AND128_MASK(src1, src2) _mm_and_ps(src1, src2) +#define MS_OR128_F32(src1, src2) _mm_or_ps(src1, src2) +#define MS_AND128_MASK_F32(src1, src2) _mm_and_ps(src1, src2) +#define MS_AND128_F32(src1, src2) _mm_and_ps(src1, src2) + +#define MS128_ANDNOT_F32(src1, src2) _mm_andnot_ps(src1, src2) +#define MS128_SRLI_EPI32(src1, src2) _mm_srli_epi32(src1, src2) +#define MS128_AND_EPI32(src1, src2) _mm_and_si128(src1, src2) +#define MS128_CASTPS_EPI32(src) _mm_castps_si128(src) +#define MS_CVT128EPI32_PS(src) _mm_cvtepi32_ps(src) +#define MS_CAST128_F32_S32(src) _mm_castsi128_ps(src) + +static inline MS_FLOAT32X4 MS_POW128_F32(MS_FLOAT32X4 src1, MS_FLOAT32X4 src2) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = powf(MS_F32X4_GETI(src1, 0), MS_F32X4_GETI(src2, 0)); + MS_F32X4_GETI(dst, 1) = powf(MS_F32X4_GETI(src1, 1), MS_F32X4_GETI(src2, 1)); + MS_F32X4_GETI(dst, 2) = powf(MS_F32X4_GETI(src1, 2), MS_F32X4_GETI(src2, 2)); + MS_F32X4_GETI(dst, 3) = powf(MS_F32X4_GETI(src1, 3), MS_F32X4_GETI(src2, 3)); + return dst; +} + +#ifdef ENABLE_AVX // FMA is only available when AVX is enabled; SSE-only builds fall back to separate mul/add below.
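/* [Editor's note: illustrative sketch, not part of the original patch.]
 * With AVX, MS_FMADD128_F32(a, b, c) maps to _mm_fmadd_ps and computes a*b + c with a single rounding step; the
 * SSE-only fallback rounds the product and the sum separately, so the two paths can differ in the last ulp.
 * Likewise MS_FSMUL128_F32(a, b, c) computes a - b*c in both variants, which is why the FMA form is written as
 * _mm_fnmadd_ps(c, b, a) = -(c*b) + a. */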
+#define MS_FMADD128_F32(src1, src2, src3) _mm_fmadd_ps(src1, src2, src3) +#define MS_FMSUB128_F32(src1, src2, src3) _mm_fmsub_ps(src1, src2, src3) +#define MS_FSMUL128_F32(src1, src2, src3) _mm_fnmadd_ps(src3, src2, src1) +#else +#define MS_FMADD128_F32(src1, src2, src3) _mm_add_ps(_mm_mul_ps(src1, src2), src3) +#define MS_FMSUB128_F32(src1, src2, src3) _mm_sub_ps(_mm_mul_ps(src1, src2), src3) +#define MS_FSMUL128_F32(src1, src2, src3) _mm_sub_ps(src1, _mm_mul_ps(src2, src3)) +#endif + +#define MS128_INT16_TO_FLOAT16(src) _mm_cvtepi16_ph(src) +#define MS128_FLOAT16_TO_INT16(src) _mm_cvttph_epi16(src) + +#define MS128_INT32_TO_FLOAT16(src) _mm_cvtepi32_ph(src) +#define MS128_FLOAT16_TO_INT32(src) _mm_cvttph_epi32(src) + +#define MS128_INT32_TO_FLOAT32(src) _mm_cvtepi32_ps(src) +#define MS128_FLOAT32_TO_INT32(src) _mm_cvttps_epi32(src) + +#define MS128_INT64_TO_FLOAT32(src) _mm_cvtepi64_ps(src) +#define MS128_FLOAT32_TO_INT64(src) _mm_cvttps_epi64(src) + +#define MS128_INT64_TO_FLOAT16(src) _mm_cvtepi64_ph(src) +#define MS128_FLOAT16_TO_INT64(src) _mm_cvttph_epi64(src) + +#define MS128_INT32_TO_FLOAT64(src) _mm_cvtepi32_pd(src) +#define MS128_FLOAT64_TO_INT32(src) _mm_cvttpd_epi32(src) + +#define MS128_INT64_TO_FLOAT64(src) _mm_cvtepi64_pd(src) +#define MS128_FLOAT64_TO_INT64(src) _mm_cvttpd_epi64(src) + +#define MS128_INT16_TO_INT32(src) _mm128_cvtepi16_epi32(src) +#define MS128_INT16_TO_INT64(src) _mm128_cvtepi16_epi64(src) +#define MS128_INT32_TO_INT16(src) _mm128_cvtepi32_epi16(src) +#define MS128_INT32_TO_INT64(src) _mm128_cvtepi32_epi64(src) +#define MS128_INT64_TO_INT16(src) _mm128_cvtepi64_epi16(src) +#define MS128_INT64_TO_INT32(src) _mm128_cvtepi64_epi32(src) + +static inline MS_FLOAT32X4 MS_ABS128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = fabsf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = fabsf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = fabsf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = fabsf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { + MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); + MS_FLOAT128_F32 sign = MS_DIV128_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS128_F32(src, abs_src) MS_DIV128_F32(abs_src, src) + +static inline MS_FLOAT32X4 MS_COS128_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 pi = {PI, PI, PI, PI}; + static const MS_FLOAT32X4 pi2_neg = {-2 * PI, -2 * PI, -2 * PI, -2 * PI}; + static const MS_FLOAT32X4 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT128_F32 src_abs = MS_ABS128_F32(src); + MS_FLOAT128_F32 src_cycle = + MS_ADD128_F32(MS_MUL128_F32(MS_FLOOR128_F32(MS_MUL128_F32(MS_ADD128_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + static const MS_FLOAT128_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT128_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT128_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT128_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT128_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X4 square = MS_MUL128_F32(src_cycle, src_cycle); + + MS_FLOAT32X4 tmp = + MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(MS_MUL128_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X4 tmp1 = 
MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(tmp, neg), square), data2); + MS_FLOAT128_F32 res = MS_ADD128_F32( + MS_MUL128_F32( + MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + +static inline MS_FLOAT32X4 MS128_LOG_F32(MS_FLOAT32X4 src) { + const MS_INT128_EPI32 gFloatExpMask = MS_MOV128_EPI32(0xffULL << 23); + const MS_INT128_EPI32 gFloatExp0 = MS_MOV128_EPI32(127ULL << 23); + const MS_INT128_EPI32 gExpNormalizer = MS_MOV128_EPI32(127); + static const MS_FLOAT128_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT128_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT128_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT128_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT128_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT128_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT128_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X4 ln2 = {LN2, LN2, LN2, LN2}; + + const MS_INT128_EPI32 exps32 = MS128_SRLI_EPI32(MS128_AND_EPI32(gFloatExpMask, MS128_CASTPS_EPI32(src)), 23); + const MS_INT128_EPI32 normExps = MS_SUB128_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X4 expsPD = MS_CVT128EPI32_PS(normExps); + const MS_FLOAT32X4 y = + MS_OR128_F32(MS_CAST128_F32_S32(gFloatExp0), MS128_ANDNOT_F32(MS_CAST128_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X4 div = MS_DIV128_F32(MS_ADD128_F32(y, neg), MS_ADD128_F32(y, pos)); + MS_FLOAT32X4 square = MS_MUL128_F32(div, div); + + MS_FLOAT32X4 tmp = MS_ADD128_F32( + MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(square, MS_ADD128_F32(MS_MUL128_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X4 tmp1 = MS_MUL128_F32(square, MS_ADD128_F32(MS_MUL128_F32(square, tmp), data4)); + MS_FLOAT32X4 res = + MS_ADD128_F32(MS_MUL128_F32(ln2, expsPD), MS_MUL128_F32(MS_MUL128_F32(div, MS_ADD128_F32(tmp1, data5)), data6)); + MS_FLOAT32X4 mask = MS_CMPEQ128_F32(src, MS_MOV128_F32(0.0f)); + res = MS_BLEND128_F32(res, MS_MOV128_F32(-INFINITY), mask); + mask = MS_CMPEQ128_F32(src, MS_MOV128_F32(INFINITY)); + res = MS_BLEND128_F32(res, MS_MOV128_F32(INFINITY), mask); + mask = MS_OR128_F32(MS_CMPLT128_F32(src, MS_MOV128_F32(0.0f)), MS_CMPUNORD128_F32(src, MS_MOV128_F32(0.0f))); + res = MS_BLEND128_F32(res, MS_MOV128_F32(NAN), mask); + return res; +} + +static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline float MS_GET_MAX128_F32(__m128 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // sse block num : 4 + result = fmaxf(result, MS_F32X4_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM128_F32(__m128 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // sse block num : 4 + result = result + MS_F32X4_GETI(src, i); + } + return result; +} + +#define STORE128X8_F32(output_ptr, num, dst) \ + MS_STQ_F32(output_ptr + 0 * num, dst##1); \ + MS_STQ_F32(output_ptr + 1 * num, dst##2); \ + MS_STQ_F32(output_ptr + 2 * num, dst##3); \ + 
MS_STQ_F32(output_ptr + 3 * num, dst##4); \ + MS_STQ_F32(output_ptr + 4 * num, dst##5); \ + MS_STQ_F32(output_ptr + 5 * num, dst##6); \ + MS_STQ_F32(output_ptr + 6 * num, dst##7); \ + MS_STQ_F32(output_ptr + 7 * num, dst##8); + +static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + MS_INT32X4 integer = MS_CVTQPS_EPI32(MS_FLOOR128_F32(MS_FMADD128_F32(input, param[6], param[4]))); + MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); + MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); + MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); + tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); + MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); + return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); +} + +static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + MS_STQ_F32(dst, VexpFp32(input)); +} + +static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + return VexpFp32(input); +} + +static inline MS_FLOAT32X4 simd_hexp128_f32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = exp(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = exp(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = exp(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = exp(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X4 square = MS_MULQ_F32(src, src); + MS_FLOAT32X4 a = + MS_MUL128_F32(MS_FMADD128_F32(MS_FMADD128_F32(MS_ADD128_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X4 b = + MS_FMADD128_F32(MS_FMADD128_F32(MS_FMADD128_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X4 res = MS_DIVQ_F32(a, b); + MS_FLOAT32X4 up_limit = MS_MOV128_F32(5.0f); + MS_FLOAT32X4 down_limit = MS_MOV128_F32(-5.0f); + MS_FLOAT32X4 up_mask = 
MS_CMPGT128_F32(src, up_limit); + MS_FLOAT32X4 down_mask = MS_CMPLT128_F32(src, down_limit); + res = MS_BLEND128_F32(res, pos, up_mask); + res = MS_BLEND128_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH128_F32 MS_TANHX4_F32 + +static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3)); + return dst; +} + +#define MS_FMADD128X8_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \ + dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \ + dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \ + dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \ + dst##8 = MS_MLAQ_F32(src##8, weight, dst##8); + +#define MS_LOAD128X4_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); + +#define MS_FMADD128X4_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); + +#define MS_LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define MS_SET_ZERO128X8_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f); + +#define MS_SET_ZERO128X4_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); +#endif // NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32IndirectRow.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32IndirectRow.c new file mode 100644 index 00000000..21f356a2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32IndirectRow.c @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include <immintrin.h> +#else +#include <x86intrin.h> +#endif +#include "nnacl_c/fp32/conv_depthwise_fp32.h" + +#define INPUT_SIZE 25 + +void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6) { + input_stride /= sizeof(float *); + size_t c8 = UP_DIV(channels, C8NUM) * C8NUM; + size_t c8_mod = channels % C8NUM; + float *in[INPUT_SIZE]; + for (int i = 0; i < output_width; ++i) { + for (int k = 0; k < INPUT_SIZE; k++) { + in[k] = input[k]; + } + input += input_stride; + size_t c = c8; + const float *w = weights; + const float *bias1 = bias; + for (; c >= C8NUM; c -= C8NUM) { + __m256 out1 = _mm256_loadu_ps(bias1); + bias1 += 8; + for (int k = 0; k < INPUT_SIZE; k += 5) { + __m256 in1 = _mm256_loadu_ps(in[k]); + __m256 w1 = _mm256_loadu_ps(w); + __m256 in2 = _mm256_loadu_ps(in[k + 1]); + __m256 w2 = _mm256_loadu_ps(w + 8); + out1 = _mm256_fmadd_ps(in1, w1, out1); + __m256 in3 = _mm256_loadu_ps(in[k + 2]); + __m256 w3 = _mm256_loadu_ps(w + 16); + out1 = _mm256_fmadd_ps(in2, w2, out1); + __m256 in4 = _mm256_loadu_ps(in[k + 3]); + __m256 w4 = _mm256_loadu_ps(w + 24); + out1 = _mm256_fmadd_ps(in3, w3, out1); + __m256 in5 = _mm256_loadu_ps(in[k + 4]); + __m256 w5 = _mm256_loadu_ps(w + 32); + out1 = _mm256_fmadd_ps(in4, w4, out1); + w += 40; + in[k] += C8NUM; + in[k + 1] += C8NUM; + in[k + 2] += C8NUM; + in[k + 3] += C8NUM; + in[k + 4] += C8NUM; + out1 = _mm256_fmadd_ps(in5, w5, out1); + } + if (relu6 != 0) { + __m256 relu6_data = _mm256_set1_ps(6.0); + out1 = _mm256_min_ps(out1, relu6_data); + } + if (relu != 0 || relu6 != 0) { + __m256 zero = _mm256_setzero_ps(); + out1 = _mm256_max_ps(out1, zero); + } + if (c > C8NUM || c8_mod == 0) { + _mm256_storeu_ps(output, out1); + output += C8NUM; + } else { + __m128 tmp; + switch (c8_mod) { + case 1: + _mm_store_ss(output, _mm256_castps256_ps128(out1)); + break; + case 2: + _mm_storel_pi((__m64 *)output, _mm256_castps256_ps128(out1)); + break; + case 3: + tmp = _mm256_castps256_ps128(out1); + _mm_storel_pi((__m64 *)output, tmp); + tmp = _mm_unpackhi_ps(tmp, tmp); + _mm_store_ss(output + 2, tmp); + break; + case 4: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + break; + case 5: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + _mm_store_ss(output + 4, _mm256_extractf128_ps(out1, 1)); + break; + case 6: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + _mm_storel_pi((__m64 *)(output + 4), _mm256_extractf128_ps(out1, 1)); + break; + case 7: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + tmp = _mm256_extractf128_ps(out1, 1); + _mm_storel_pi((__m64 *)(output + 4), tmp); + tmp = _mm_unpackhi_ps(tmp, tmp); + _mm_store_ss(output + 6, tmp); + break; + default: + _mm256_storeu_ps(output, out1); + break; + } + output += c8_mod; + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32Row_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32Row_sse.c new file mode 100644 index 00000000..09b7e956 ---
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/ConvDwFp32Row_sse.c @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step) { + size_t out_c16 = DOWN_DIV(output_channel, C16NUM) * C16NUM; + size_t out_c8 = DOWN_DIV(output_channel, C8NUM) * C8NUM; + size_t out_c4 = DOWN_DIV(output_channel, C4NUM) * C4NUM; + for (int i = 0; i < num_pixels; i++) { + const float *weight_tmp = weight_ptr; + const float *input_tmp = input_ptr; + size_t out_c = 0; + for (; out_c < out_c16; out_c += C16NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 dst2 = _mm_loadu_ps(output_ptr + 4); + __m128 dst3 = _mm_loadu_ps(output_ptr + 8); + __m128 dst4 = _mm_loadu_ps(output_ptr + 12); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 w2 = _mm_loadu_ps(weight_tmp + 4); + __m128 w3 = _mm_loadu_ps(weight_tmp + 8); + __m128 w4 = _mm_loadu_ps(weight_tmp + 12); + __m128 in1 = _mm_loadu_ps(input_tmp); + __m128 in2 = _mm_loadu_ps(input_tmp + 4); + __m128 in3 = _mm_loadu_ps(input_tmp + 8); + __m128 in4 = _mm_loadu_ps(input_tmp + 12); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + dst2 = MS_MLAQ_F32(dst2, w2, in2); + dst3 = MS_MLAQ_F32(dst3, w3, in3); + dst4 = MS_MLAQ_F32(dst4, w4, in4); + _mm_storeu_ps(output_ptr, dst1); + _mm_storeu_ps(output_ptr + 4, dst2); + _mm_storeu_ps(output_ptr + 8, dst3); + _mm_storeu_ps(output_ptr + 12, dst4); + output_ptr += 16; + input_tmp += 16; + weight_tmp += 16; + } + for (; out_c < out_c8; out_c += C8NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 dst2 = _mm_loadu_ps(output_ptr + 4); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 w2 = _mm_loadu_ps(weight_tmp + 4); + __m128 in1 = _mm_loadu_ps(input_tmp); + __m128 in2 = _mm_loadu_ps(input_tmp + 4); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + dst2 = MS_MLAQ_F32(dst2, w2, in2); + _mm_storeu_ps(output_ptr, dst1); + _mm_storeu_ps(output_ptr + 4, dst2); + output_ptr += 8; + input_tmp += 8; + weight_tmp += 8; + } + for (; out_c < out_c4; out_c += C4NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 in1 = _mm_loadu_ps(input_tmp); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + _mm_storeu_ps(output_ptr, dst1); + output_ptr += 4; + input_tmp += 4; + weight_tmp += 4; + } + for (; out_c < output_channel; out_c++) { + *output_ptr++ += weight_ptr[out_c] * input_ptr[out_c]; + } + input_ptr += input_step; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/DepthwiseFp32_Sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/DepthwiseFp32_Sse.c new file mode 100644 index 00000000..82832584 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/DepthwiseFp32_Sse.c @@ -0,0 +1,327 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_SSE +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" + +#ifndef ENABLE_AVX +void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) { + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + kernel_w_step /= sizeof(float); + + const float *src_kh = src; + const float *weight_kh = weight; + __m128 dst_ma = _mm_setzero_ps(); + + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + + int c1 = 0; + int c4 = DOWN_DIV(width, C4NUM) * C4NUM; + int c2 = DOWN_DIV(width, C2NUM) * C2NUM; + // c4 loop + for (; c1 < c4; c1 += C4NUM) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step); + __m128 src_ma3 = _mm_loadu_ps(src_kw + 2 * in_kw_step); + __m128 src_ma4 = _mm_loadu_ps(src_kw + 3 * in_kw_step); + + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 weight_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM); + __m128 weight_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM); + + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2); + __m128 mul_ma3 = _mm_mul_ps(src_ma3, weight_ma3); + __m128 mul_ma4 = _mm_mul_ps(src_ma4, weight_ma4); + + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma2); + dst_ma = _mm_add_ps(dst_ma, mul_ma3); + dst_ma = _mm_add_ps(dst_ma, mul_ma4); + + src_kw += in_kw_step * 4; + weight_kw += C4NUM * 4; + } + + // c2 loop + for (; c1 < c2; c1 += C2NUM) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step); + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2); + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma2); + + src_kw += in_kw_step * 2; + weight_kw += C4NUM * 2; + } + + // remaining + for (; c1 < width; ++c1) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + + src_kw += in_kw_step; + weight_kw += C4NUM; + } + + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } + + __m128 bias_ma = _mm_loadu_ps(bias); + dst_ma = _mm_add_ps(dst_ma, bias_ma); + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + dst_ma = _mm_max_ps(zero_ma, dst_ma); + if (relu6) { + __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + dst_ma = 
_mm_min_ps(const_ma, dst_ma); + } + } + _mm_storeu_ps(dst, dst_ma); +} +#endif + +void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) { + out_h_step /= sizeof(float); + block_channel /= sizeof(float); + in_sh_step /= sizeof(float); + in_sw_step /= sizeof(float); + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + int c4 = DOWN_DIV(width, C4NUM) * C4NUM; + int c2 = DOWN_DIV(width, C2NUM) * C2NUM; + int c1 = 0; + // c4 loop + for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + __m128 dst_w_ma2 = _mm_setzero_ps(); + __m128 dst_w_ma3 = _mm_setzero_ps(); + __m128 dst_w_ma4 = _mm_setzero_ps(); + + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + + __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + + __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step); + __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3); + + __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step); + __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); + } // kernel_w loop + } // kernel_h loop + + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma); + + ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6); + + _mm_storeu_ps(dst_w, dst_w_ma1); + _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); + _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3); + _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4); + } // dst_width loop + + // c2 loop + for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + __m128 dst_w_ma2 = _mm_setzero_ps(); + + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + + __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step); + __m128 
weight_kw_ma2 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + } // kernel_w loop + } // kernel_h loop + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); + + ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6); + + _mm_storeu_ps(dst_w, dst_w_ma1); + _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); + } + + // remaining + for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + } // kernel_w loop + } // kernel_h loop + + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + ActBlock1(&dst_w_ma1, relu, relu6); + _mm_storeu_ps(dst_w, dst_w_ma1); + } + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} + +void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, + size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, + size_t in_kh_step, size_t in_kw_step) { + out_h_step /= sizeof(float); + block_channel /= sizeof(float); + in_sh_step /= sizeof(float); + in_sw_step /= sizeof(float); + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + __m128 src_w_ma = _mm_loadu_ps(src_w); + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + + int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM; + int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM; + int c1 = 0; + // c4 loop + for (; c1 < c4; c1 += C4NUM) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2); + + __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step); + __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM); + __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3); + _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3); + + __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step); + __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM); + __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); + _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4); + + dst_kw += 4 * in_kw_step; + weight_kw += 4 * C4NUM; + } + // c2 loop + for (; c1 < c2; c1 += C2NUM) { 
+ __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2); + + dst_kw += 2 * in_kw_step; + weight_kw += 2 * C4NUM; + } + // remaining + for (; c1 < kernel_w; ++c1) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c new file mode 100644 index 00000000..17103cdb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/MatMul_Sse.c @@ -0,0 +1,243 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_SSE +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" +#include "nnacl_c/base/minimal_filtering_generator.h" + +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + const float *src1 = matix_a; + int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; + int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; + for (int i = 0; i < m; ++i) { + const float *src1_n = src1; + const float *src2_n = matrix_b; + for (int j = 0; j < n; ++j) { + const float *src1_j = src1_n; + int y = 0; + // 16 channel + for (; y < c16; y += C16NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(); + __m128 dst4 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + __m128 ma3 = _mm_loadu_ps(src1_j + 8); + __m128 ma4 = _mm_loadu_ps(src1_j + 12); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + __m128 tmp3 = _mm_mul_ps(ma3, mb); + __m128 tmp4 = _mm_mul_ps(ma4, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + dst3 = _mm_add_ps(dst3, tmp3); + dst4 = _mm_add_ps(dst4, tmp4); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + _mm_storeu_ps(matrix_c + 8, dst3); + _mm_storeu_ps(matrix_c + 12, dst4); + src1_j -= in_channel * k; + src1_j += C16NUM; + matrix_c += C16NUM; + } + // 8 channel + for (; y < c8; y += C8NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + src1_j -= in_channel * k; + src1_j += C8NUM; + matrix_c += C8NUM; + } + // remaining channels + for (; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + *matrix_c++ = tmp; + } + src2_n += 1; + } + src1 += k * in_channel; + } +} + +void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, int write_mode) { + int C8Steps = row * C8NUM, WinoSteps1 = stride * col, WinoSteps2 = stride * C8NUM; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b, *bias_d = bias; + float *dst = NULL; + for (int cc = col; cc > 0; cc -= C8NUM) { + if (write_mode != 0) { // writec8 + dst = c; + } + const float *srca_d = a; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(), dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(), dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = depth; d > 0; --d) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3
= _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + + if (bias != NULL) { + DoBiasBlock8(bias_d, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); + bias_d += C8NUM; + } + + ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); + + if (write_mode == OutType_TileC8) { // WriteWino + c = dst + WinoSteps2; + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); + } else if (write_mode == OutType_C8) { // WriteC8 + _mm_storeu_ps(c, dst1), _mm_storeu_ps(c + 4, dst2); + _mm_storeu_ps(c + 8, dst3), _mm_storeu_ps(c + 12, dst4); + _mm_storeu_ps(c + 16, dst5), _mm_storeu_ps(c + 20, dst6); + _mm_storeu_ps(c + 24, dst7), _mm_storeu_ps(c + 28, dst8); + c += C8Steps; + } else { + switch (cc) { + case 1: // write1 + c = dst + 1; + WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 1, r); + break; + case 2: // write2 + c = dst + 2; + WriteCol2Opt(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); + break; + case 3: // write3 + c = dst + 3; + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst1); + WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 3, r); + break; + case 4: // write4 + c = dst + 4; + WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 4, r); + break; + case 5: // write5 + c = dst + 5; + WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 5, r); + break; + case 6: // write6 + c = dst + 6; + WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 6, r); + break; + case 7: // write7 + c = dst + 7; + WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 7, r); + break; + default: // write8 + c = dst + C8NUM; + WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 8, r); + break; + } + } + if (cc <= C8NUM) break; // write end + } + a += C4NUM * depth; + if (write_mode == OutType_C8) c += 32; + if (write_mode == OutType_TileC8) c = dst + WinoSteps2; + if (write_mode == OutType_Nhwc) c = dst - col; + if (r <= C4NUM) break; + } +} + +void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col) { + for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { + const float *srca_d = a; + float *dst = c; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(); + __m128 dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = 0; d < depth; d++) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = 
_mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8); + dst += 32; + c = dst; + } + b += depth * C8NUM; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c new file mode 100644 index 00000000..0526761b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/PostFuncBiasReluC8.c @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type) { + stride /= sizeof(float); + for (int loop_c8 = 0; loop_c8 != oc8div; loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m128 bias1 = _mm_setzero_ps(), bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + 4); + bias += 8; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM, src += 32) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + __m128 src5 = _mm_loadu_ps(src + 16); + __m128 src6 = _mm_loadu_ps(src + 20); + __m128 src7 = _mm_loadu_ps(src + 24); + __m128 src8 = _mm_loadu_ps(src + 28); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias2); + src5 = _mm_add_ps(src5, bias1); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias1); + src8 = _mm_add_ps(src8, bias2); + + ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm_storeu_ps(dst_c8, src1); + _mm_storeu_ps(dst_c8 + 4, src2); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src3); + _mm_storeu_ps(dst_c8 + 4, src4); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src5); + _mm_storeu_ps(dst_c8 + 4, src6); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src7); + _mm_storeu_ps(dst_c8 + 4, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c8 += stride) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + + _mm_storeu_ps(dst_c8, src1); + _mm_storeu_ps(dst_c8 + 4, src2); + } + } + + if (oc8mod == 0) return; + + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + 4); + bias += 8; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c1 += stride) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + + switch (oc8mod) { + case 1: + _mm_store_ss(dst_c1, src1); + break; + case 2: + _mm_storel_pi((__m64 *)(dst_c1), src1); + break; + case 3: + _mm_storel_pi((__m64 *)(dst_c1), src1); + src1 = _mm_unpackhi_ps(src1, src1); + _mm_store_ss(dst_c1 + 2, src1); + break; + case 4: + _mm_storeu_ps(dst_c1, src1); + break; + case 5: + _mm_storeu_ps(dst_c1, src1); + _mm_store_ss(dst_c1 + 4, src2); + break; + case 6: + _mm_storeu_ps(dst_c1, src1); + _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); + break; + case 7: + _mm_storeu_ps(dst_c1, src1); + _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); + src2 = _mm_unpackhi_ps(src2, src2); + _mm_store_ss(dst_c1 + 6, src2); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c new file mode 100644 index 00000000..0460476f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/TiledC4MatMulFp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +static inline void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, + const __m128 weight, const float v1, const float v2, const float v3, + const float v4) { + *dst1 = _mm_add_ps(*dst1, _mm_mul_ps(weight, _mm_set_ps1(v1))); + *dst2 = _mm_add_ps(*dst2, _mm_mul_ps(weight, _mm_set_ps1(v2))); + *dst3 = _mm_add_ps(*dst3, _mm_mul_ps(weight, _mm_set_ps1(v3))); + *dst4 = _mm_add_ps(*dst4, _mm_mul_ps(weight, _mm_set_ps1(v4))); +} + +static inline void TiledC4MatmulFp32_LoadData(__m128 *src1, __m128 *src2, __m128 *src3, __m128 *src4, + const float *src) { + *src1 = _mm_loadu_ps(src); + *src2 = _mm_loadu_ps(src + 4); + *src3 = _mm_loadu_ps(src + 8); + *src4 = _mm_loadu_ps(src + 12); +} + +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { + const float *src_tmp = src; + for (int i = 0; i < oc4; ++i) { + float *dst_tmp = dst; + src = src_tmp; + size_t ic4_tmp = ic4 - 1; + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + src += 16; + __m128 weight_data[4]; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight_data[2] = _mm_loadu_ps(weight + 8); + weight_data[3] = _mm_loadu_ps(weight + 12); + weight += 16; + __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))); + __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))); + __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))); + __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))); + __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))); + __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))); + __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + if (ic4_tmp != 0) { + ic4_tmp -= 1; + 
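+      /* Software-pipelined main loop over the remaining ic4 blocks: the loads of
+       * the next src/weight tile are interleaved with the multiply-adds of the
+       * current tile, keeping all eight dst accumulators busy between stores. */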
TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)))); + for (; ic4_tmp != 0; ic4_tmp -= 1) { + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src1, 3)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src2, 3)))); + src1 = _mm_loadu_ps(src); + src2 = _mm_loadu_ps(src + 4); + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src3, 3)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src4, 3)))); + src3 = _mm_loadu_ps(src + 8); + src4 = _mm_loadu_ps(src + 12); + src += 16; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], MS_F32X4_GETI(src1, 0), + MS_F32X4_GETI(src2, 0), MS_F32X4_GETI(src3, 0), MS_F32X4_GETI(src4, 0)); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], MS_F32X4_GETI(src1, 3), + MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3)); + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)))); + } + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], MS_F32X4_GETI(src1, 3), + MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3)); + + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + for (int j = 0; j < 4; ++j) { + 
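+        /* Pipeline drain: fold the last pre-loaded src tile into the second
+         * group of accumulators (dst5..dst8) before the full tile is stored. */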
TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + } + _mm_storeu_ps(dst, dst1); + _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3); + _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5); + _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7); + _mm_storeu_ps(dst + 28, dst8); + dst = dst_tmp + cal_num; + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c new file mode 100644 index 00000000..7c382e93 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradPostFuncBiasReluC4.c @@ -0,0 +1,349 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/intrinsics/sse/sse_common.h" + +void WinogradPostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type) { + size_t stride = oc4div + oc4mod; + plane_stride /= sizeof(float); + int loop_c4 = 0; + size_t src_stride = plane_size * C4NUM + plane_stride; + for (; loop_c4 <= (int)(oc4div)-C16NUM; loop_c4 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + __m128 bias3 = _mm_setzero_ps(); + __m128 bias4 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias3 = _mm_loadu_ps(bias + C8NUM); + bias4 = _mm_loadu_ps(bias + C12NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM); + __m128 src13 = _mm_loadu_ps(src + src_stride * C3NUM); + __m128 src14 = _mm_loadu_ps(src + src_stride * C3NUM + C4NUM); + + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src9 = _mm_add_ps(src9, bias3); + src10 = _mm_add_ps(src10, bias3); + src13 = _mm_add_ps(src13, bias4); + src14 = _mm_add_ps(src14, bias4); + + ActBlock8(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + _mm_storeu_ps(dst_c4 + C8NUM, src9); + _mm_storeu_ps(dst_c4 + C12NUM, src13); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + 
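+      /* channels 8..15 of this output row come from the third and fourth C4 strips */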
_mm_storeu_ps(dst_c4 + C8NUM, src10); + _mm_storeu_ps(dst_c4 + C12NUM, src14); + dst_c4 += stride; + + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + __m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM); + __m128 src15 = _mm_loadu_ps(src + src_stride * C3NUM + C8NUM); + __m128 src16 = _mm_loadu_ps(src + src_stride * C3NUM + C12NUM); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + src11 = _mm_add_ps(src11, bias3); + src12 = _mm_add_ps(src12, bias3); + src15 = _mm_add_ps(src15, bias4); + src16 = _mm_add_ps(src16, bias4); + + ActBlock8(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + _mm_storeu_ps(dst_c4 + C8NUM, src11); + _mm_storeu_ps(dst_c4 + C12NUM, src15); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + _mm_storeu_ps(dst_c4 + C8NUM, src12); + _mm_storeu_ps(dst_c4 + C12NUM, src16); + dst_c4 += stride; + src += C16NUM; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + __m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src4 = _mm_loadu_ps(src + src_stride * C3NUM); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias3); + src4 = _mm_add_ps(src4, bias4); + + ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + _mm_storeu_ps(dst_c4 + C8NUM, src3); + _mm_storeu_ps(dst_c4 + C12NUM, src4); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += C3NUM * src_stride; + } + for (; loop_c4 <= (int)(oc4div)-C12NUM; loop_c4 += C12NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + __m128 bias3 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias3 = _mm_loadu_ps(bias + C8NUM); + bias += C12NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + __m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM); + __m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + src9 = _mm_add_ps(src9, bias3); + src10 = _mm_add_ps(src10, bias3); + src11 = _mm_add_ps(src11, bias3); + src12 = _mm_add_ps(src12, bias3); + + ActBlock12(&src1, &src2, &src3, 
&src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, relu_type == 1, + relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + _mm_storeu_ps(dst_c4 + C8NUM, src9); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + _mm_storeu_ps(dst_c4 + C8NUM, src10); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + _mm_storeu_ps(dst_c4 + C8NUM, src11); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + _mm_storeu_ps(dst_c4 + C8NUM, src12); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + __m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias3); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src3, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + _mm_storeu_ps(dst_c4 + C8NUM, src3); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += C2NUM * src_stride; + } + + for (; loop_c4 <= (int)(oc4div)-C8NUM; loop_c4 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + + ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src2, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += src_stride; + } + for (; loop_c4 < (int)(oc4div); loop_c4 += C4NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias += C4NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp 
-= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + + ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + src1 = _mm_add_ps(src1, bias1); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + } + if (oc4mod == 0) { + return; + } + __m128 bias1 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias += C4NUM; + } + float *dst_c1 = dst + oc4div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + src += C4NUM; + src1 = _mm_add_ps(src1, bias1); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + + switch (oc4mod) { + case 1: + _mm_store_ss(dst_c1, src1); + dst_c1 += stride; + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), src1); + dst_c1 += stride; + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), src1); + src1 = _mm_unpackhi_ps(src1, src1); + _mm_store_ss(dst_c1 + C2NUM, src1); + dst_c1 += stride; + break; + case C4NUM: + _mm_storeu_ps(dst_c1, src1); + dst_c1 += stride; + break; + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c new file mode 100644 index 00000000..168a5273 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/WinogradTrans.c @@ -0,0 +1,376 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#include "nnacl_c/fp32/common_func_fp32.h" + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c4 = length * 4; + size_t S_step = length * w * 4; + for (int h1 = 0; h1 < h; ++h1) { + const float *SW = S; + memset(M, 0, len_c4 * w * sizeof(float)); + for (int w_tmp = w; w_tmp > 0; --w_tmp) { + const float *SK = SW; + const float *BK = B; + int k_tmp = k; + for (; k_tmp >= 7; k_tmp -= 7) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + __m128 k5 = _mm_load_ps1(BK + 4 * h); + __m128 k6 = _mm_load_ps1(BK + 5 * h); + __m128 k7 = _mm_load_ps1(BK + 6 * h); + BK += 7 * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + M2 = _mm_fmadd_ps(s4, k4, M2); + __m128 s5 = _mm_loadu_ps(SK + 4 * S_step); + M1 = _mm_fmadd_ps(s5, k5, M1); + __m128 s6 = _mm_loadu_ps(SK + 5 * S_step); + M2 = _mm_fmadd_ps(s6, k6, M2); + __m128 s7 = _mm_loadu_ps(SK + 6 * S_step); + M1 = _mm_fmadd_ps(s7, k7, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + __m128 s5 = _mm_loadu_ps(SK + 4 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s5, k5)); + __m128 s6 = _mm_loadu_ps(SK + 5 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s6, k6)); + __m128 s7 = _mm_loadu_ps(SK + 6 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 7 * S_step - len_c4; + } + for (; k_tmp >= 4; k_tmp -= 4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + BK += 4 * h; + int len_tmp = length; +#ifdef ENABLE_AVX + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, SK += C8NUM, M += C8NUM) { + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_loadu_ps(M + C4NUM); + __m128 s1 = _mm_loadu_ps(SK); + __m128 s11 = _mm_loadu_ps(SK + C4NUM); + M1 = _mm_fmadd_ps(s1, k1, M1); + M2 = _mm_fmadd_ps(s11, k1, M2); + __m128 s2 = _mm_loadu_ps(SK + S_step); + __m128 s22 = _mm_loadu_ps(SK + S_step + C4NUM); + M1 = _mm_fmadd_ps(s2, k2, M1); + M2 = _mm_fmadd_ps(s22, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + __m128 s33 = _mm_loadu_ps(SK + 2 * S_step + C4NUM); + M1 = _mm_fmadd_ps(s3, k3, M1); + M2 = _mm_fmadd_ps(s33, k3, M2); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + __m128 s44 = _mm_loadu_ps(SK + 3 * S_step + C4NUM); + M1 = _mm_fmadd_ps(s4, k4, M1); + M2 = _mm_fmadd_ps(s44, k4, M2); + _mm_storeu_ps(M, M1); + _mm_storeu_ps(M + C4NUM, M2); + } +#endif + for (; len_tmp > 0; --len_tmp, SK += 4, M += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 
= _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + M2 = _mm_fmadd_ps(s4, k4, M2); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 4 * S_step - len_c4; + } + for (; k_tmp >= 3; k_tmp -= 3) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + BK += 3 * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 3 * S_step - len_c4; + } + for (; k_tmp > 0; k_tmp -= 1) { + __m128 k1 = _mm_load_ps1(BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); +#ifdef ENABLE_AVX + M1 = _mm_fmadd_ps(s0, k1, M1); +#else + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); +#endif + _mm_storeu_ps(M, M1); + } + M -= len_c4; + SK += S_step - len_c4; + } + SW += len_c4; + M += len_c4; + } + B += 1; + } +} + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c4 = length * 4, k_step = len_c4 * k; + for (int h1 = 0; h1 < h; ++h1, S += k_step) { + const float *BW = B; + memset(M, 0, len_c4 * w * sizeof(float)); + for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c4) { + const float *SK = S, *BK = BW; + int k_tmp = k; + for (; k_tmp >= 7; k_tmp -= 7, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + __m128 k5 = _mm_load_ps1(BK + 4 * h); + __m128 k6 = _mm_load_ps1(BK + 5 * h); + __m128 k7 = _mm_load_ps1(BK + 6 * h); + BK += 7 * h; + const float *S2 = SK + len_c4, *S3 = S2 + len_c4; + const float *S4 = S3 + len_c4, *S5 = S4 + len_c4; + const float *S6 = S5 + len_c4, *S7 = S6 + len_c4; + for (int len_tmp = length; len_tmp > 0; + --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4, S5 += 4, S6 += 4, S7 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(S4); + M2 = _mm_fmadd_ps(s4, k4, 
M2); + __m128 s5 = _mm_loadu_ps(S5); + M1 = _mm_fmadd_ps(s5, k5, M1); + __m128 s6 = _mm_loadu_ps(S6); + M2 = _mm_fmadd_ps(s6, k6, M2); + __m128 s7 = _mm_loadu_ps(S7); + M1 = _mm_fmadd_ps(s7, k7, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(S4); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + __m128 s5 = _mm_loadu_ps(S5); + M1 = _mm_add_ps(M1, _mm_mul_ps(s5, k5)); + __m128 s6 = _mm_loadu_ps(S6); + s1 = _mm_add_ps(s1, _mm_mul_ps(s6, k6)); + __m128 s7 = _mm_loadu_ps(S7); + M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S7; + } + for (; k_tmp >= 4; k_tmp -= 4, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + BK += 4 * h; + const float *S2 = SK + len_c4; + const float *S3 = S2 + len_c4; + const float *S4 = S3 + len_c4; + int len_tmp = length; +#ifdef ENABLE_AVX + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, S4 += C8NUM) { + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_loadu_ps(M + C4NUM); + __m128 s1 = _mm_loadu_ps(SK); + __m128 s11 = _mm_loadu_ps(SK + C4NUM); + M1 = _mm_fmadd_ps(s1, k1, M1); + M2 = _mm_fmadd_ps(s11, k1, M2); + __m128 s2 = _mm_loadu_ps(S2); + __m128 s22 = _mm_loadu_ps(S2 + C4NUM); + M1 = _mm_fmadd_ps(s2, k2, M1); + M2 = _mm_fmadd_ps(s22, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + __m128 s33 = _mm_loadu_ps(S3 + C4NUM); + M1 = _mm_fmadd_ps(s3, k3, M1); + M2 = _mm_fmadd_ps(s33, k3, M2); + __m128 s4 = _mm_loadu_ps(S4); + __m128 s44 = _mm_loadu_ps(S4 + C4NUM); + M1 = _mm_fmadd_ps(s4, k4, M1); + M2 = _mm_fmadd_ps(s44, k4, M2); + _mm_storeu_ps(M, M1); + _mm_storeu_ps(M + C4NUM, M2); + } +#endif + for (; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(S4); + M2 = _mm_fmadd_ps(s4, k4, M2); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(S4); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S4; + } + for (; k_tmp >= 3; k_tmp -= 3, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + BK += 3 * h; + const float *S2 = SK + len_c4; + const float *S3 = S2 + len_c4; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s0, k1, M1); + __m128 s1 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s1, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + M1 = _mm_add_ps(M1, M2); + 
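+          /* store the merged accumulator; M1 and M2 ran as independent FMA
+           * chains to shorten the loop-carried dependency path */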
_mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S3; + } + for (; k_tmp >= 1; k_tmp -= 1, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); +#ifdef ENABLE_AVX + M1 = _mm_fmadd_ps(s0, k1, M1); +#else + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); +#endif + _mm_storeu_ps(M, M1); + } + } + } + } +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/sse_common.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/sse_common.h new file mode 100644 index 00000000..5885954a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/intrinsics/sse/sse_common.h @@ -0,0 +1,390 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ +#define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ + +#define SSE_ROW_NUM_1 1 +#define SSE_ROW_NUM_2 2 +#define SSE_ROW_NUM_3 3 + +#define SSE_INDEX_1 1 +#define SSE_INDEX_2 2 +#define SSE_INDEX_3 3 +#define SSE_INDEX_4 4 +#define SSE_INDEX_5 5 +#define SSE_INDEX_6 6 + +#define SSE_SHUFFLE_0321 (_MM_SHUFFLE(0, 3, 2, 1)) + +#define SSE_ACT_RELU 1 +#define SSE_ACT_RELU6 3 + +static inline void ActBlock1(__m128 *v1, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + } +} + +static inline void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + } +} + +static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + *v3 = _mm_max_ps(zero_ma, *v3); + *v4 = _mm_max_ps(zero_ma, *v4); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + *v3 = _mm_min_ps(relu6_ma, *v3); + *v4 = _mm_min_ps(relu6_ma, *v4); + } +} + +static inline void ActBlock12(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, + __m128 *v8, __m128 *v9, __m128 *v10, __m128 *v11, __m128 *v12, size_t relu, + size_t relu6) { + if (relu || relu6) { + __m128 zero_ma = 
_mm_setzero_ps(); + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + *v3 = _mm_max_ps(zero_ma, *v3); + *v4 = _mm_max_ps(zero_ma, *v4); + *v5 = _mm_max_ps(zero_ma, *v5); + *v6 = _mm_max_ps(zero_ma, *v6); + *v7 = _mm_max_ps(zero_ma, *v7); + *v8 = _mm_max_ps(zero_ma, *v8); + *v9 = _mm_max_ps(zero_ma, *v9); + *v10 = _mm_max_ps(zero_ma, *v10); + *v11 = _mm_max_ps(zero_ma, *v11); + *v12 = _mm_max_ps(zero_ma, *v12); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + *v3 = _mm_min_ps(relu6_ma, *v3); + *v4 = _mm_min_ps(relu6_ma, *v4); + *v5 = _mm_min_ps(relu6_ma, *v5); + *v6 = _mm_min_ps(relu6_ma, *v6); + *v7 = _mm_min_ps(relu6_ma, *v7); + *v8 = _mm_min_ps(relu6_ma, *v8); + *v9 = _mm_min_ps(relu6_ma, *v9); + *v10 = _mm_min_ps(relu6_ma, *v10); + *v11 = _mm_min_ps(relu6_ma, *v11); + *v12 = _mm_min_ps(relu6_ma, *v12); + } +} + +static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, + __m128 *v8, size_t relu_type) { + __m128 relu6 = _mm_set_ps1(6.0); + __m128 zero = _mm_setzero_ps(); + switch (relu_type) { + case SSE_ACT_RELU6: + *v1 = _mm_min_ps(*v1, relu6); + *v2 = _mm_min_ps(*v2, relu6); + *v3 = _mm_min_ps(*v3, relu6); + *v4 = _mm_min_ps(*v4, relu6); + *v5 = _mm_min_ps(*v5, relu6); + *v6 = _mm_min_ps(*v6, relu6); + *v7 = _mm_min_ps(*v7, relu6); + *v8 = _mm_min_ps(*v8, relu6); + case SSE_ACT_RELU: + *v1 = _mm_max_ps(*v1, zero); + *v2 = _mm_max_ps(*v2, zero); + *v3 = _mm_max_ps(*v3, zero); + *v4 = _mm_max_ps(*v4, zero); + *v5 = _mm_max_ps(*v5, zero); + *v6 = _mm_max_ps(*v6, zero); + *v7 = _mm_max_ps(*v7, zero); + *v8 = _mm_max_ps(*v8, zero); + default: + break; + } +} + +static inline void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_store_ss(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst7); + } +} + +static inline void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst3); + } + if 
(r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst7); + *dst += stride; + *dst += SSE_INDEX_2; + } +} + +static inline void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + 
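+    /* fourth row: four full lanes from dst7, then lanes 4 and 5 of dst8
+     * written one float at a time via shuffle rotation */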
_mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, + __m128 *dst5, __m128 *dst6, __m128 *dst7, __m128 *dst8) { + __m128 bias1 = _mm_loadu_ps(bias_ptr); + __m128 bias2 = _mm_loadu_ps(bias_ptr + C4NUM); + *dst1 = _mm_add_ps(*dst1, bias1); + *dst2 = _mm_add_ps(*dst2, bias2); + *dst3 = _mm_add_ps(*dst3, bias1); + *dst4 = _mm_add_ps(*dst4, bias2); + *dst5 = _mm_add_ps(*dst5, bias1); + *dst6 = _mm_add_ps(*dst6, bias2); + *dst7 = _mm_add_ps(*dst7, bias1); + *dst8 = _mm_add_ps(*dst8, bias2); +} + +#endif // MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.c new file mode 100644 index 00000000..19a32c48 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.c @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/kernel/init_exec_env.h"
+
+static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][16];
+
+void RegKernelCreator(int opType, int dataType, KernelCreator creator) {
+  g_kernelCreatorRegistry[opType][REGIST_DT(dataType)] = creator;
+}
+
+void Init_MSC_VER_kernels(void) {
+#ifdef _MSC_VER
+  /* The MSVC build cannot auto-register kernels through constructor
+   * attributes, so they are registered here on first use. */
+  static bool inited = false;
+  if (inited == false) {
+    init_vs_kernels(g_kernelCreatorRegistry);
+    inited = true;
+  }
+#endif
+  return;
+}
+
+bool checkOpValid(int opType) {
+  if (opType < PrimType_MIN || opType >= PrimType_MAX) {
+    return false;
+  }
+  return true;
+}
+
+bool SupportKernelC(int opType, int dataType) {
+  Init_MSC_VER_kernels();
+  const int length = 16;
+  if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) {
+    return false;
+  }
+  if (!checkOpValid(opType)) {
+    return false;
+  }
+  KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)];
+  return creator != NULL;
+}
+
+int DefaultThreadUpdate(int32_t type, int64_t load, int64_t store, int64_t unit, int thread) {
+  return thread > 0 ? thread : 1;
+}
+
+int NNACLKernelInferShape(struct KernelBase *self) { return NNACL_ERR; }
+
+int NNACLCheckKernelBase(KernelBase *kernel_base) {
+  CheckExecEnv(kernel_base);
+
+  if (kernel_base->param_ == NULL) {
+    return NNACL_ERR;
+  }
+
+  if (kernel_base->thread_nr_ <= 0 || kernel_base->thread_nr_ > MAX_THREAD_NUM) {
+    return NNACL_ERR;
+  }
+
+  if (kernel_base->in_size_ == 0 || kernel_base->in_ == NULL) {
+    return NNACL_ERR;
+  }
+  if (kernel_base->out_size_ == 0 || kernel_base->out_ == NULL) {
+    return NNACL_ERR;
+  }
+  return NNACL_OK;
+}
+
+KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size,
+                         int data_type, ExecEnv *env) {
+  Init_MSC_VER_kernels();
+  if (param == NULL) {
+    return NULL;
+  }
+  if (!checkOpValid(param->type_)) {
+    return NULL;
+  }
+
+  KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)];
+  if (creator == NULL) {
+    return NULL;
+  }
+
+  KernelBase *kernel_base = creator(param, data_type);
+  if (kernel_base == NULL) {
+    return NULL;
+  }
+
+  kernel_base->InferShape = NNACLKernelInferShape;
+  kernel_base->UpdateThread = DefaultThreadUpdate;
+  kernel_base->env_ = env;
+  kernel_base->param_ = param;
+  kernel_base->thread_nr_ = param->thread_num_;
+  kernel_base->train_session_ = param->is_train_session_;
+  kernel_base->in_ = ins;
+  kernel_base->in_size_ = in_size;
+  kernel_base->out_ = outs;
+  kernel_base->out_size_ = out_size;
+
+  int ret = NNACLCheckKernelBase(kernel_base);
+  if (ret != NNACL_OK) {
+    free(kernel_base); /* the creator malloc'ed the kernel; do not leak it on validation failure */
+    return NULL;
+  }
+
+  return kernel_base;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.h
new file mode 100644
index 00000000..84378c7f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_H_ +#define NNACL_KERNEL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/infer/common_infer.h" + +typedef struct ExecEnv { + void *allocator_; + void *thread_pool_; + void *(*Alloc)(void *allocator, size_t sz); + void (*Free)(void *allocator, void *ptr); + int (*ParallelLaunch)(void *thread_pool, void *task, void *param, int task_num); +} ExecEnv; + +typedef struct KernelBase { + int (*Release)(struct KernelBase *self); + int (*Prepare)(struct KernelBase *self); + int (*Compute)(struct KernelBase *self); + int (*Resize)(struct KernelBase *self); + int (*InferShape)(struct KernelBase *self); + int (*UpdateThread)(int32_t type, int64_t load, int64_t store, int64_t unit, int thread); + OpParameter *param_; + int thread_nr_; + ExecEnv *env_; + TensorC **in_; + size_t in_size_; + TensorC **out_; + size_t out_size_; + bool train_session_; + void *workspace_; /* only used in train */ + int work_size_; /* only used in train */ +} KernelBase; + +#ifdef _MSC_VER +#define REG_KERNEL_CREATOR(op, data_type, func) +#else +#define REG_KERNEL_CREATOR(op, data_type, func) \ + __attribute__((constructor(102))) void Reg##op##data_type##Creator() { RegKernelCreator(op, data_type, func); } +#endif + +#define REGIST_DT(DataType) (DataType - kNumberTypeBegin - 1) +typedef KernelBase *(*KernelCreator)(OpParameter *param, int data_type); +void RegKernelCreator(int opType, int dataType, KernelCreator func); + +#ifdef __cplusplus +extern "C" { +#endif +KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size, + int data_type, ExecEnv *env); +bool SupportKernelC(int opType, int dataType); +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.c new file mode 100644 index 00000000..50cb1389 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.c @@ -0,0 +1,194 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/activation.h" +#include "nnacl_c/activation_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/activation_fp16.h" +#endif + +typedef struct ActivationStruct { + KernelBase base; + int data_type_; + ActType act_type_; +} ActivationStruct; + +int ActivationResize(struct KernelBase *self) { + ActivationStruct *activation = (ActivationStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(activation); + self->thread_nr_ = self->UpdateThread(TC_TYPE(PrimType_Activation, activation->act_type_), 1, 1, + NNACLGetElementNum(self->out_[0]), self->thread_nr_); + return NNACL_OK; +} + +int activation_fp32_run(ActivationStruct *activation, int task_id, int count, int stride) { + ActivationParameter *param = (ActivationParameter *)activation->base.param_; + float *input = activation->base.in_[0]->data_; + float *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return Fp32Relu(input + task_id * stride, count, output + task_id * stride); + case ActType_Relu6: + return Fp32Relu6(input + task_id * stride, count, output + task_id * stride); + case ActType_LeakyRelu: + return LRelu(input + task_id * stride, count, output + task_id * stride, param->alpha_); + case ActType_Sigmoid: + return Sigmoid(input + task_id * stride, count, output + task_id * stride); + case ActType_Tanh: + return Tanh(input + task_id * stride, count, output + task_id * stride); + case ActType_Swish: + return Swish(input + task_id * stride, count, output + task_id * stride); + case ActType_HSwish: + return HSwish(input + task_id * stride, count, output + task_id * stride); + case ActType_HSigmoid: + return HSigmoid(input + task_id * stride, count, output + task_id * stride); + case ActType_HardTanh: + return HardTanh(input + task_id * stride, count, output + task_id * stride, param->min_val_, param->max_val_); + case ActType_Gelu: + return Gelu(input + task_id * stride, count, output + task_id * stride, param->approximate_); + case ActType_Softplus: + return Softplus(input + task_id * stride, count, output + task_id * stride); + case ActType_Elu: + return Elu(input + task_id * stride, count, output + task_id * stride, param->alpha_); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int activation_int32_run(ActivationStruct *activation, int task_id, int count, int stride) { + int32_t *input = activation->base.in_[0]->data_; + int32_t *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return Int32Relu(input + task_id * stride, count, output + task_id * stride); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int activation_fp16_run(ActivationStruct *activation, int task_id, int count, int stride) { +#ifdef ENABLE_FP16 + ActivationParameter *param = (ActivationParameter *)activation->base.param_; + float16_t *input = activation->base.in_[0]->data_; + float16_t *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return ReluFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Relu6: + return Relu6Fp16(input + stride * task_id, 
output + stride * task_id, count); + case ActType_LeakyRelu: + return LReluFp16(input + stride * task_id, output + stride * task_id, count, param->alpha_); + case ActType_Sigmoid: + return SigmoidFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Tanh: + return TanhFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HSwish: + return HSwishFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Swish: + return SwishFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HSigmoid: + return HSigmoidFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HardTanh: + return HardTanhFp16(input + stride * task_id, count, output + stride * task_id, param->min_val_, param->max_val_); + case ActType_Gelu: + return GeluFp16(input + stride * task_id, count, output + stride * task_id, true); + case ActType_Softplus: + return SoftplusFp16(input + stride * task_id, count, output + stride * task_id); + case ActType_Elu: + return EluFp16(input + stride * task_id, count, output + stride * task_id, param->alpha_); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +#endif + return NNACL_DISABLE_FP16; +} + +int ActivationImpl(void *cdata, int task_id, float l, float r) { + ActivationStruct *activation = (ActivationStruct *)cdata; + + int ele_num = NNACLGetElementNum(activation->base.in_[0]); + NNACL_CHECK_ZERO_RETURN_ERR(activation->base.thread_nr_); + int stride = UP_DIV(ele_num, activation->base.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, NNACL_ERR); + int count = MSMIN(stride, ele_num - stride * task_id); + if (count <= 0) { + return NNACL_OK; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(stride, task_id), NNACL_ERR); + + switch (activation->data_type_) { + case kNumberTypeFloat32: + return activation_fp32_run(activation, task_id, count, stride); + case kNumberTypeFloat16: + return activation_fp16_run(activation, task_id, count, stride); + case kNumberTypeInt32: + return activation_int32_run(activation, task_id, count, stride); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int ActivationCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ActivationImpl, self, self->thread_nr_); +} + +KernelBase *CreateActivation(OpParameter *param, int data_type) { + ActivationParameter *act = (ActivationParameter *)(param); + + int type = act->type_; + if (data_type == kNumberTypeInt32) { + if (type != ActType_Relu) { + return NULL; + } + } + + if (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16) { + if (type != ActType_Relu && type != ActType_Relu6 && type != ActType_LeakyRelu && type != ActType_Sigmoid && + type != ActType_Tanh && type != ActType_HSwish && type != ActType_Swish && type != ActType_HardTanh && + type != ActType_Gelu && type != ActType_HSigmoid && type != ActType_Softplus && type != ActType_Elu) { + return NULL; + } + } + + ActivationStruct *activation = (ActivationStruct *)malloc(sizeof(ActivationStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(activation); + memset(activation, 0, sizeof(ActivationStruct)); + + activation->data_type_ = data_type; + activation->act_type_ = act->type_; + activation->base.Prepare = DefaultPrepare1In1Out; + activation->base.Release = DefaultRelease; + activation->base.Resize = ActivationResize; + activation->base.Compute = ActivationCompute; + return (KernelBase *)activation; +} + +REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeFloat32, 
CreateActivation)
+REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeFloat16, CreateActivation)
+REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeInt32, CreateActivation)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.h
new file mode 100644
index 00000000..b7b06c01
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/activation.h
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_KERNEL_ACTIVATION_H_
+#define NNACL_KERNEL_ACTIVATION_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+KernelBase *CreateActivation(OpParameter *param, int data_type);
+
+#endif  // NNACL_KERNEL_ACTIVATION_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.c
new file mode 100644
index 00000000..3916981a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.c
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/addn.h" +#include "nnacl_c/fp32/add_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arithmetic_fp16.h" +#endif + +int AddNLaunch(void *cdata, int task_id, float l, float r) { + AddNStruct *addn = (AddNStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + int count_per_thread = UP_DIV(addn->elements_num_, addn->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, count_per_thread, NNACL_ERR); + int count = MSMIN(count_per_thread, addn->elements_num_ - task_id * count_per_thread); + int stride = count_per_thread * task_id; + +#ifdef ENABLE_FP16 + if (addn->data_type_ == kNumberTypeFloat16) { + return ElementAddFp16((float16_t *)addn->in1_addr_ + stride, (float16_t *)addn->in2_addr_ + stride, + (float16_t *)addn->out_addr_ + stride, count); + } +#endif + return ElementAdd((float *)addn->in1_addr_ + stride, (float *)addn->in2_addr_ + stride, + (float *)addn->out_addr_ + stride, count); +} + +void AddNCompute(AddNStruct *addn, bool same_shape, bool first_scalar) { +#ifdef ENABLE_FP16 + if (addn->data_type_ == kNumberTypeFloat16) { + if (same_shape) { + ElementAddFp16((float16_t *)addn->in1_addr_, (float16_t *)addn->in2_addr_, (float16_t *)addn->out_addr_, + addn->elements_num_); + } else { + ElementOptAddFp16((float16_t *)addn->in1_addr_, (float16_t *)addn->in2_addr_, (float16_t *)addn->out_addr_, + addn->elements_num_, first_scalar); + } + return; + } +#endif + + if (same_shape) { + ElementAdd((float *)addn->in1_addr_, (float *)addn->in2_addr_, (float *)addn->out_addr_, addn->elements_num_); + } else { + ElementOptAdd((float *)addn->in1_addr_, (float *)addn->in2_addr_, (float *)addn->out_addr_, addn->elements_num_, + first_scalar); + } + return; +} + +int AddNComputeNoParallel(AddNStruct *addn) { + TensorC *in0_tensor = addn->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in0_tensor); + TensorC *in1_tensor = addn->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in1_tensor); + AddNCompute(addn, NNACLIsShapeSame(in0_tensor, in1_tensor), NNACLGetElementNum(in0_tensor) == 1); + + for (size_t i = Index2; i < addn->base_.in_size_; i++) { + TensorC *in_tensor = addn->base_.in_[i]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + addn->in1_addr_ = in_tensor->data_; + addn->in2_addr_ = addn->out_addr_; + AddNCompute(addn, NNACLIsShapeSame(in_tensor, addn->base_.out_[OUTPUT_INDEX]), NNACLGetElementNum(in_tensor) == 1); + } + return NNACL_OK; +} + +int AddnResize(struct KernelBase *self) { + AddNStruct *addn = (AddNStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + addn->elements_num_ = NNACLGetElementNum(out_tensor); + return NNACL_OK; +} + +int AddnCompute(struct KernelBase *self) { + AddNStruct *addn = (AddNStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + addn->in1_addr_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in1_addr_); + addn->in2_addr_ = self->in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in2_addr_); + addn->out_addr_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->out_addr_); + + if (addn->elements_num_ < self->thread_nr_) { + return AddNComputeNoParallel(addn); + } + + for (int i = 0; i < self->in_size_; i++) { + TensorC *in_tensor = self->in_[i]; + if (!NNACLIsShapeSame(in_tensor, self->out_[OUTPUT_INDEX])) { + return NNACL_ADDN_SHAPE_UNMATCH; + } + } + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, 
AddNLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + + for (size_t i = Index2; i < self->in_size_; ++i) { + addn->in1_addr_ = self->in_[i]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in1_addr_); + addn->in2_addr_ = addn->out_addr_; + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, AddNLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +KernelBase *CreateAddN(OpParameter *param, int data_type) { + AddNStruct *addn = (AddNStruct *)malloc(sizeof(AddNStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(addn); + addn->data_type_ = data_type; + addn->base_.Prepare = DefaultPrepare1In1Out; + addn->base_.Resize = AddnResize; + addn->base_.Release = DefaultRelease; + addn->base_.Compute = AddnCompute; + return (KernelBase *)addn; +} + +REG_KERNEL_CREATOR(PrimType_AddN, kNumberTypeFloat16, CreateAddN) +REG_KERNEL_CREATOR(PrimType_AddN, kNumberTypeFloat32, CreateAddN) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.h new file mode 100644 index 00000000..90430088 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/addn.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ADDN_H_ +#define NNACL_KERNEL_ADDN_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct AddNStruct { + KernelBase base_; + int data_type_; + int elements_num_; + void *in1_addr_; + void *in2_addr_; + void *out_addr_; +} AddNStruct; + +KernelBase *CreateAddN(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ADDN_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.c new file mode 100644 index 00000000..2bd0dd99 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.c @@ -0,0 +1,127 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/arg_min_max.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/arg_min_max_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arg_min_max_fp16.h" +#endif + +int ArgMinMaxPrepare(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + arg_min_max->arg_elements_alloc_ = param->topk_ > Num1 || param->keep_dims_; + arg_min_max->compute_.topk_ = param->topk_; + arg_min_max->compute_.axis_ = param->axis_; + arg_min_max->compute_.keep_dims_ = param->keep_dims_; + arg_min_max->compute_.out_value_ = param->out_value_; + arg_min_max->compute_.get_max_ = self->param_->type_ == PrimType_ArgMinFusion ? false : true; + return NNACL_OK; +} + +int ArgMinMaxResize(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxComputeParam *compute = &arg_min_max->compute_; + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + ComputeStrides(input_tensor->shape_, compute->in_strides_, input_tensor->shape_size_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + ComputeStrides(output_tensor->shape_, compute->out_strides_, output_tensor->shape_size_); + + compute->dims_size_ = (int)input_tensor->shape_size_; + compute->axis_ = compute->axis_ < 0 ? compute->axis_ + compute->dims_size_ : compute->axis_; + NNACL_CHECK_FALSE(compute->topk_ <= 0, NNACL_ARG_MIN_MAX_AXIS_INVALID); + NNACL_CHECK_FALSE(compute->topk_ > input_tensor->shape_[compute->axis_], NNACL_ARG_MIN_MAX_AXIS_INVALID); + return NNACL_OK; +} + +int ArgMinMaxCompute(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + void *in_data = in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + void *out_data = out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + + void *out_value = NULL; + if (self->out_size_ == TWO_TENSOR) { + out_value = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_value); + } + + if (arg_min_max->arg_elements_alloc_) { + int arg_size = in_tensor->shape_[arg_min_max->compute_.axis_] * (int)sizeof(ArgElement); + NNACL_CHECK_MALLOC_SIZE(arg_size); + arg_min_max->compute_.arg_elements_ = (ArgElement *)self->env_->Alloc(self->env_->allocator_, arg_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arg_min_max->compute_.arg_elements_); + } + + int ret = NNACL_OK; + int *in_shape = in_tensor->shape_; + if (in_tensor->data_type_ == kNumberTypeFloat32) { + ArgMinMaxFp32((float *)in_data, out_data, (float *)out_value, in_shape, &arg_min_max->compute_); +#ifdef ENABLE_FP16 + } else if (in_tensor->data_type_ == kNumberTypeFloat16) { + ArgMinMaxFp16((float16_t *)in_data, out_data, (float16_t *)out_value, in_shape, &arg_min_max->compute_); +#endif + } else if (in_tensor->data_type_ == kNumberTypeInt32) { + ArgMinMaxInt32((int32_t *)in_data, out_data, (int32_t *)out_value, in_shape, &arg_min_max->compute_); 
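+    /* fp32, fp16 (when ENABLE_FP16 is set) and int32 are the only input
+     * dtypes dispatched in this chain; anything else takes the else branch
+     * below and reports NNACL_UNSUPPORTED_DATA_TYPE. */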
+ } else { + ret = NNACL_UNSUPPORTED_DATA_TYPE; + } + + if (arg_min_max->arg_elements_alloc_) { + self->env_->Free(self->env_->allocator_, arg_min_max->compute_.arg_elements_); + arg_min_max->compute_.arg_elements_ = NULL; + } + return ret; +} + +KernelBase *CreateArgMinMax(OpParameter *param, int data_type) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)malloc(sizeof(ArgMinMaxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arg_min_max); + memset(arg_min_max, 0, sizeof(ArgMinMaxStruct)); + + arg_min_max->base_.Prepare = ArgMinMaxPrepare; + arg_min_max->base_.Resize = ArgMinMaxResize; + arg_min_max->base_.Release = DefaultRelease; + arg_min_max->base_.Compute = ArgMinMaxCompute; + return (KernelBase *)arg_min_max; +} + +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeInt32, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat16, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat32, CreateArgMinMax) + +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeInt32, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat16, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat32, CreateArgMinMax) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.h new file mode 100644 index 00000000..52c7f7f3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arg_min_max.h @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_KERNEL_ARG_MIN_MAX_H_
+#define NNACL_KERNEL_ARG_MIN_MAX_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#ifdef ENABLE_ARM64
+#include <arm_neon.h>
+#endif
+
+typedef struct ArgElement {
+  uint32_t index_;
+  union ArgData {
+    int8_t i8_data_;
+    int32_t i_data_;
+    float f_data_;
+#ifdef ENABLE_ARM
+#ifdef ENABLE_FP16
+    float16_t f16_data_;
+#endif
+#endif
+  } data_;
+} ArgElement;
+
+typedef int (*COMPARE_FUNCTION)(const void *a, const void *b);
+
+typedef struct ArgMinMaxComputeParam {
+  int32_t axis_;
+  int32_t dims_size_;
+  int32_t topk_;
+  bool get_max_;
+  bool keep_dims_;
+  bool out_value_;
+  int32_t in_strides_[COMM_SHAPE_SIZE];
+  int32_t out_strides_[COMM_SHAPE_SIZE];
+  ArgElement *arg_elements_;
+} ArgMinMaxComputeParam;
+
+typedef struct ArgMinMaxStruct {
+  KernelBase base_;
+  ArgMinMaxComputeParam compute_;
+  bool arg_elements_alloc_;
+} ArgMinMaxStruct;
+
+KernelBase *CreateArgMinMax(OpParameter *param, int data_type);
+
+#endif  // NNACL_KERNEL_ARG_MIN_MAX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.c
new file mode 100644
index 00000000..973efbf5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.c
@@ -0,0 +1,725 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/arithmetic.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/mul_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arithmetic_fp16.h" +#endif + +void InitArithmeticRunFunction(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + + ArithmeticFuncions fun_table[] = { + {PrimType_MulFusion, ActType_Relu, ElementMulRelu, ElementMulReluInt, NULL, ElementOptMulRelu, ElementOptMulReluInt, + NULL}, + {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6, ElementMulRelu6Int, NULL, ElementOptMulRelu6, + ElementOptMulRelu6Int, NULL}, + {PrimType_MulFusion, ActType_No, ElementMul, ElementMulInt, NULL, ElementOptMul, ElementOptMulInt, NULL}, + {PrimType_AddFusion, ActType_Relu, ElementAddRelu, NULL, NULL, ElementOptAddRelu, NULL, NULL}, + {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6, NULL, NULL, ElementOptAddRelu6, NULL, NULL}, + {PrimType_AddFusion, ActType_No, ElementAdd, ElementAddInt, NULL, ElementOptAdd, ElementOptAddInt, NULL}, + {PrimType_SubFusion, ActType_Relu, ElementSubRelu, NULL, NULL, ElementOptSubRelu, NULL, NULL}, + {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6, NULL, NULL, ElementOptSubRelu6, NULL, NULL}, + {PrimType_SubFusion, ActType_No, ElementSub, ElementSubInt, NULL, ElementOptSub, ElementOptSubInt, NULL}, + {PrimType_DivFusion, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL}, + {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL}, + {PrimType_DivFusion, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL}, + {PrimType_RealDiv, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL}, + {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL}, + {PrimType_RealDiv, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL}, + {PrimType_LogicalAnd, ActType_No, ElementLogicalAnd, ElementLogicalAndInt, ElementLogicalAndBool, + ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool}, + {PrimType_LogicalOr, ActType_No, ElementLogicalOr, NULL, ElementLogicalOrBool, NULL, NULL, ElementOptLogicalOrBool}, + {PrimType_Maximum, ActType_No, ElementMaximum, ElementMaximumInt, NULL, ElementOptMaximum, ElementOptMaximumInt, + NULL}, + {PrimType_Minimum, ActType_No, ElementMinimum, ElementMinimumInt, NULL, ElementOptMinimum, ElementOptMinimumInt, + NULL}, + {PrimType_FloorMod, ActType_No, ElementFloorMod, ElementFloorModInt, NULL, ElementOptFloorMod, + ElementOptFloorModInt, NULL}, + {PrimType_FloorDiv, ActType_No, ElementFloorDiv, ElementFloorDivInt, NULL, ElementOptFloorDiv, + ElementOptFloorDivInt, NULL}, + {PrimType_Mod, ActType_No, ElementMod, ElementModInt, NULL, ElementOptMod, ElementOptModInt, NULL}, + {PrimType_SquaredDifference, ActType_No, ElementSquaredDifference, NULL, NULL, ElementOptSquaredDifference, NULL, + NULL}}; + + size_t length = sizeof(fun_table) / sizeof(ArithmeticFuncions); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == arithmetic->primitive_type_ && + fun_table[i].activation_type_ == ((ArithmeticParameter *)(arithmetic->base_.param_))->activation_type_) { + arithmetic->functions_ = fun_table[i]; + return; + } + } +} + +int ArithmeticRelease(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + for 
(int i = 0; i < TWO_TENSOR; i++) { + if (arithmetic->broadcast_buffer_[i] != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->broadcast_buffer_[i]); + arithmetic->broadcast_buffer_[i] = NULL; + } + } + + for (int i = 0; i < arithmetic->block_boundary_infos_size_; i++) { + if (arithmetic->block_boundary_infos_[i].a_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].a_offset_); + arithmetic->block_boundary_infos_[i].a_offset_ = NULL; + } + if (arithmetic->block_boundary_infos_[i].b_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].b_offset_); + arithmetic->block_boundary_infos_[i].b_offset_ = NULL; + } + } + arithmetic->block_boundary_infos_size_ = 0; + + if (arithmetic->a_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->a_matrix_.batch_post_sum_); + arithmetic->a_matrix_.batch_post_sum_ = NULL; + } + + if (arithmetic->b_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->b_matrix_.batch_post_sum_); + arithmetic->b_matrix_.batch_post_sum_ = NULL; + } + + if (arithmetic->c_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->c_matrix_.batch_post_sum_); + arithmetic->c_matrix_.batch_post_sum_ = NULL; + } + return NNACL_OK; +} + +void ArithmeticComputeOffset(ArithmeticStruct *arithmetic, int task_id) { + ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id]; + block_info->init_offset_ = true; + + int64_t b_start = block_info->batch_begin_; + int64_t b_end = block_info->batch_end_; + int64_t s_end = block_info->size_end_; + if (s_end != 0) { + ++b_end; + } + int offset_index = 0; + for (; b_start < b_end; ++b_start) { + int64_t delta = b_start; + int64_t a_offset = 0; + int64_t b_offset = 0; + for (int j = 0; j <= arithmetic->batch_tail_dim_; ++j) { + if (j > 0) { + delta = delta % arithmetic->c_matrix_.batch_post_sum_[j]; + } + if (j < arithmetic->batch_tail_dim_) { + a_offset += (delta / arithmetic->c_matrix_.batch_post_sum_[j + 1] * arithmetic->a_matrix_.shape_[j] / + arithmetic->c_matrix_.shape_[j]) * + arithmetic->a_matrix_.batch_post_sum_[j + 1]; + b_offset += (delta / arithmetic->c_matrix_.batch_post_sum_[j + 1] * arithmetic->b_matrix_.shape_[j] / + arithmetic->c_matrix_.shape_[j]) * + arithmetic->b_matrix_.batch_post_sum_[j + 1]; + } else { + a_offset += (delta * arithmetic->a_matrix_.shape_[j] / arithmetic->c_matrix_.shape_[j]); + b_offset += (delta * arithmetic->b_matrix_.shape_[j] / arithmetic->c_matrix_.shape_[j]); + } + } + block_info->a_offset_[offset_index] = a_offset * arithmetic->a_matrix_.inner_size_ * arithmetic->in_data_size_; + block_info->b_offset_[offset_index] = b_offset * arithmetic->b_matrix_.inner_size_ * arithmetic->in_data_size_; + offset_index++; + } +} + +int ArithmeticDoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)base; + int data_type = arithmetic->base_.in_[FIRST_INPUT]->data_type_; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(input1); + + if (data_type == kNumberTypeFloat32) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_f32_); + return arithmetic->functions_.optimzie_f32_((const float *)input0, (const float *)input1, (float *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + 
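/* Non-scalar path: the earlier broadcast/offset preparation has already
+       * aligned both operands with this output block, so the plain
+       * element-wise f32 kernel runs over `size` contiguous values. */
+      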
NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_f32_); + return arithmetic->functions_.compute_f32_((const float *)input0, (const float *)input1, (float *)output, size); + } + } + + if (data_type == kNumberTypeBool) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_bool_); + return arithmetic->functions_.optimzie_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_bool_); + return arithmetic->functions_.compute_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size); + } + } + + if (data_type == kNumberTypeInt32) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_int_); + return arithmetic->functions_.optimzie_int_((const int *)input0, (const int *)input1, (int *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_int_); + return arithmetic->functions_.compute_int_((const int *)input0, (const int *)input1, (int *)output, size); + } + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int ArithmeticRun(void *cdata, int task_id, float l, float r) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)cdata; + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id >= arithmetic->block_boundary_infos_size_, NNACL_ERR); + + if (arithmetic->block_boundary_infos_[task_id].init_offset_ == false) { + ArithmeticComputeOffset(arithmetic, task_id); + } + + ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id]; + int64_t b_start = block_info->batch_begin_; + int64_t s_start = block_info->size_begin_; + int64_t s_end = block_info->size_end_; + int64_t index_start = 0; + int64_t index_end = block_info->batch_end_ - b_start; + uint8_t *a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + uint8_t *b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start]; + uint8_t *c_ptr = (uint8_t *)(arithmetic->c_matrix_.data_) + + (b_start * arithmetic->c_matrix_.inner_size_ + s_start) * arithmetic->out_data_size_; + if (arithmetic->a_matrix_.inner_size_ > 1) { + a_ptr += s_start * arithmetic->in_data_size_; + } + if (arithmetic->b_matrix_.inner_size_ > 1) { + b_ptr += s_start * arithmetic->in_data_size_; + } + + if (index_start == index_end) { + return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end - s_start); + } + + int64_t size = arithmetic->c_matrix_.inner_size_ - s_start; + int ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, size); + if (ret != NNACL_OK) { + return ret; + } + + ++index_start; + c_ptr += size * arithmetic->out_data_size_; + int64_t c_stride = arithmetic->c_matrix_.inner_size_ * arithmetic->out_data_size_; + for (; index_start < index_end; ++index_start) { + a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start]; + ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, arithmetic->c_matrix_.inner_size_); + if (ret != NNACL_OK) { + return ret; + } + c_ptr += c_stride; + } + if (s_end == 0) { + return NNACL_OK; + } + a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + 
block_info->b_offset_[index_start]; + return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end); +} + +void ResetArithmeticMatric(KernelBase *base, ArithmeticMatrixInfo *matrix) { + matrix->is_valid_ = false; + matrix->data_ = NULL; + matrix->inner_size_ = 1; + matrix->shape_size_ = 0; + + if (matrix->batch_post_sum_ != NULL) { + base->env_->Free(base->env_->allocator_, matrix->batch_post_sum_); + matrix->batch_post_sum_ = NULL; + } +} + +int UpdateArithmeticParameter(ArithmeticStruct *arithmetic) { + NNACL_CHECK_TRUE_RET(arithmetic->a_matrix_.shape_size_ == arithmetic->b_matrix_.shape_size_, + NNACL_ARITHMETIC_SHAPE_INVALID); + + arithmetic->ndim_ = arithmetic->a_matrix_.shape_size_; + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->c_matrix_); + + for (size_t i = 0; i < arithmetic->ndim_; ++i) { + NNACL_CHECK_TRUE_RET(arithmetic->a_matrix_.shape_[i] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID); + NNACL_CHECK_TRUE_RET(arithmetic->b_matrix_.shape_[i] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID); + arithmetic->in_shape0_[i] = arithmetic->a_matrix_.shape_[i]; + arithmetic->in_shape1_[i] = arithmetic->b_matrix_.shape_[i]; + arithmetic->out_shape_[i] = MSMAX(arithmetic->in_shape0_[i], arithmetic->in_shape1_[i]); + arithmetic->c_matrix_.shape_[arithmetic->c_matrix_.shape_size_++] = + MSMAX(arithmetic->a_matrix_.shape_[i], arithmetic->b_matrix_.shape_[i]); + } + return NNACL_OK; +} + +int OptimizeArithmeticShape(ArithmeticStruct *arithmetic) { + ArithmeticMatrixInfo *a = &arithmetic->a_matrix_; + ArithmeticMatrixInfo *b = &arithmetic->b_matrix_; + arithmetic->ndim_ = a->shape_size_ >= b->shape_size_ ? a->shape_size_ : b->shape_size_; + + int shape0[MAX_LEN] = {0}; + int shape1[MAX_LEN] = {0}; + /* init a & b shape */ + int i = 0; + for (; i < arithmetic->ndim_; ++i) { + shape0[i] = 1; + shape1[i] = 1; + } + + /* init matrix shape dim */ + int a_matrix_size = arithmetic->ndim_ - a->shape_size_; + for (i = a_matrix_size; i < arithmetic->ndim_; i++) { + shape0[i] = a->shape_[i - a_matrix_size]; + } + + int b_matrix_size = arithmetic->ndim_ - b->shape_size_; + for (i = b_matrix_size; i < arithmetic->ndim_; i++) { + shape1[i] = b->shape_[i - b_matrix_size]; + } + + /* horizontal shape dims */ + int shape0_temp[MAX_LEN] = {0}; + int shape1_temp[MAX_LEN] = {0}; + int shape_temp_size = 0; + for (i = 0; i < arithmetic->ndim_;) { // horizontal comparison, merge the part of continuous 1. + shape0_temp[shape_temp_size] = shape0[i]; + shape1_temp[shape_temp_size] = shape1[i]; + shape_temp_size++; + if (shape0[i] != 1 && shape1[i] != 1) { + ++i; + continue; + } + + size_t j0 = i; + while (j0 < arithmetic->ndim_ && shape0[j0] == 1) { + ++j0; + } + size_t j1 = i; + while (j1 < arithmetic->ndim_ && shape1[j1] == 1) { + ++j1; + } + size_t j = MSMAX(j0, j1); + while ((++i) < j) { + shape0_temp[shape_temp_size - 1] *= shape0[i]; + shape1_temp[shape_temp_size - 1] *= shape1[i]; + } + } + + arithmetic->a_matrix_.shape_size_ = 0; + arithmetic->b_matrix_.shape_size_ = 0; + + for (i = 0; i < shape_temp_size;) { // vertical comparison, merge the part of continuous equation. 
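+    /* Illustrative trace (example shapes, not taken from the original patch):
+     * after the horizontal pass above, inputs (2,3,1,1,4) vs (2,3,5,6,4) have
+     * been reduced to (2,3,1,4) vs (2,3,30,4); this loop then folds runs of
+     * equal dims, leaving a_matrix (6,1,4) and b_matrix (6,30,4). */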
+ if (shape0_temp[i] == 1 && shape1_temp[i] == 1) { + ++i; + continue; + } + shape0[arithmetic->a_matrix_.shape_size_++] = shape0_temp[i]; + shape1[arithmetic->b_matrix_.shape_size_++] = shape1_temp[i]; + if (shape0_temp[i] != shape1_temp[i]) { + ++i; + continue; + } + while ((++i) < shape_temp_size) { + if (shape0_temp[i] != shape1_temp[i]) { + break; + } + shape0[arithmetic->a_matrix_.shape_size_ - 1] *= shape0_temp[i]; + shape1[arithmetic->b_matrix_.shape_size_ - 1] *= shape1_temp[i]; + } + } + + memcpy(arithmetic->a_matrix_.shape_, shape0, arithmetic->a_matrix_.shape_size_ * sizeof(int)); + memcpy(arithmetic->b_matrix_.shape_, shape1, arithmetic->b_matrix_.shape_size_ * sizeof(int)); + + return UpdateArithmeticParameter(arithmetic); +} + +int ResetArithmeticStatus(ArithmeticStruct *arithmetic) { + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->a_matrix_); + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->b_matrix_); + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->c_matrix_); + + arithmetic->a_matrix_.shape_size_ = arithmetic->base_.in_[FIRST_INPUT]->shape_size_; + memcpy(arithmetic->a_matrix_.shape_, arithmetic->base_.in_[FIRST_INPUT]->shape_, + arithmetic->a_matrix_.shape_size_ * sizeof(int)); + arithmetic->b_matrix_.shape_size_ = arithmetic->base_.in_[SECOND_INPUT]->shape_size_; + memcpy(arithmetic->b_matrix_.shape_, arithmetic->base_.in_[SECOND_INPUT]->shape_, + arithmetic->b_matrix_.shape_size_ * sizeof(int)); + + return OptimizeArithmeticShape(arithmetic); +} + +void ArithmeticDoBroadcast(ArithmeticStruct *arithmetic, void *in_data, void *out_data, int input_index) { + int *in_shape = input_index == FIRST_INPUT ? arithmetic->in_shape0_ : arithmetic->in_shape1_; + int *in_stride = input_index == FIRST_INPUT ? arithmetic->in_strides0_ : arithmetic->in_strides1_; + int *multiples = input_index == FIRST_INPUT ? 
arithmetic->multiples0_ : arithmetic->multiples1_; + return arithmetic->tile_function_(in_data, out_data, 0, arithmetic->ndim_, in_shape, in_stride, + arithmetic->out_strides_, multiples); +} + +int CheckDivDataInvalid(ArithmeticStruct *arithmetic) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[SECOND_INPUT]); + if ((arithmetic->primitive_type_ == PrimType_DivFusion || arithmetic->primitive_type_ == PrimType_RealDiv) && + arithmetic->base_.in_[SECOND_INPUT]->data_type_ == kNumberTypeInt32) { + int element_num = NNACLGetElementNum(arithmetic->base_.in_[SECOND_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[SECOND_INPUT]->data_); + int *int_data = (int *)(arithmetic->base_.in_[SECOND_INPUT]->data_); + for (int i = 0; i < element_num; i++) { + if (int_data[i] == 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + } + return NNACL_OK; +} + +int ArithmeticBroadCastConstTensor(ArithmeticStruct *arithmetic) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + + CalcStructMultiplesAndStrides(arithmetic); + +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE + bool prefer_explicit_broadcast = false; +#else + bool prefer_explicit_broadcast = arithmetic->ndim_ != 1; +#endif + prefer_explicit_broadcast = + prefer_explicit_broadcast && (arithmetic->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeBool); + + bool exist_broadcast_ = false; + int buffer_size = NNACLGetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]) * arithmetic->in_data_size_; + if (arithmetic->a_matrix_.is_const_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[FIRST_INPUT]->data_); + if (arithmetic->in_elements_num0_ != arithmetic->out_elements_num_ && prefer_explicit_broadcast) { + exist_broadcast_ = true; + + arithmetic->a_matrix_.data_ = arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_); + arithmetic->broadcast_buffer_[Index0] = arithmetic->a_matrix_.data_; + + ArithmeticDoBroadcast(arithmetic, arithmetic->base_.in_[FIRST_INPUT]->data_, arithmetic->a_matrix_.data_, Index0); + arithmetic->in_elements_num0_ = arithmetic->out_elements_num_; + + // shape must be equal to out + for (size_t i = 0; i < arithmetic->ndim_; ++i) { + arithmetic->in_shape0_[i] = arithmetic->out_shape_[i]; + arithmetic->in_strides0_[i] = arithmetic->out_strides_[i]; + } + memcpy(arithmetic->a_matrix_.shape_, arithmetic->c_matrix_.shape_, arithmetic->ndim_ * sizeof(int)); + arithmetic->a_matrix_.is_valid_ = true; + } + } + + if (arithmetic->b_matrix_.is_const_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[SECOND_INPUT]->data_); + int ret = CheckDivDataInvalid(arithmetic); + if (ret != NNACL_OK) { + return ret; + } + if (arithmetic->in_elements_num1_ != arithmetic->out_elements_num_ && prefer_explicit_broadcast) { + exist_broadcast_ = true; + + arithmetic->b_matrix_.data_ = arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_); + arithmetic->broadcast_buffer_[Index1] = arithmetic->b_matrix_.data_; + + ArithmeticDoBroadcast(arithmetic, arithmetic->base_.in_[Index1]->data_, arithmetic->b_matrix_.data_, Index1); + arithmetic->in_elements_num1_ = arithmetic->out_elements_num_; + // shape must be equal to out + for (size_t i = 0; i < arithmetic->ndim_; ++i) { + arithmetic->in_shape1_[i] = arithmetic->out_shape_[i]; + arithmetic->in_strides1_[i] = arithmetic->out_strides_[i]; + } + + memcpy(arithmetic->b_matrix_.shape_, 
arithmetic->c_matrix_.shape_, arithmetic->ndim_ * sizeof(int));
+      arithmetic->b_matrix_.is_valid_ = true;
+    }
+  }
+  if (!exist_broadcast_) {
+    return NNACL_OK;
+  }
+  return OptimizeArithmeticShape(arithmetic);
+}
+
+int ArithmeticComputeOfflineInfo(ArithmeticStruct *arithmetic) {
+  int broadcast_pos = -1;
+  int last_dim = arithmetic->a_matrix_.shape_size_ - 1;
+  for (int i = last_dim; i >= 0; --i) {
+    if (arithmetic->a_matrix_.shape_[i] != arithmetic->b_matrix_.shape_[i]) {
+      broadcast_pos = i;
+      break;
+    }
+  }
+  arithmetic->batch_tail_dim_ = broadcast_pos;
+  if (broadcast_pos == last_dim && arithmetic->batch_tail_dim_ >= 0) {
+    --arithmetic->batch_tail_dim_;
+  }
+
+  for (int i = last_dim; i > arithmetic->batch_tail_dim_; --i) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->a_matrix_.inner_size_, arithmetic->a_matrix_.shape_[i], NNACL_ERR);
+    arithmetic->a_matrix_.inner_size_ *= arithmetic->a_matrix_.shape_[i];
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->b_matrix_.inner_size_, arithmetic->b_matrix_.shape_[i], NNACL_ERR);
+    arithmetic->b_matrix_.inner_size_ *= arithmetic->b_matrix_.shape_[i];
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->c_matrix_.inner_size_, arithmetic->c_matrix_.shape_[i], NNACL_ERR);
+    arithmetic->c_matrix_.inner_size_ *= arithmetic->c_matrix_.shape_[i];
+  }
+
+  arithmetic->a_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
+    arithmetic->base_.env_->allocator_, (arithmetic->a_matrix_.shape_size_ + 1) * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.batch_post_sum_);
+  for (int i = 0; i < arithmetic->a_matrix_.shape_size_ + 1; i++) {
+    arithmetic->a_matrix_.batch_post_sum_[i] = 1;
+  }
+
+  arithmetic->b_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
+    arithmetic->base_.env_->allocator_, (arithmetic->b_matrix_.shape_size_ + 1) * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.batch_post_sum_);
+  for (int i = 0; i < arithmetic->b_matrix_.shape_size_ + 1; i++) {
+    arithmetic->b_matrix_.batch_post_sum_[i] = 1;
+  }
+
+  arithmetic->c_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
+    arithmetic->base_.env_->allocator_, (arithmetic->c_matrix_.shape_size_ + 1) * sizeof(int));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.batch_post_sum_);
+  for (int i = 0; i < arithmetic->c_matrix_.shape_size_ + 1; i++) {
+    arithmetic->c_matrix_.batch_post_sum_[i] = 1;
+  }
+
+  for (int i = arithmetic->batch_tail_dim_; i >= 0; --i) {
+    if (i == arithmetic->batch_tail_dim_) {
+      arithmetic->a_matrix_.batch_post_sum_[i] = arithmetic->a_matrix_.shape_[i];
+      arithmetic->b_matrix_.batch_post_sum_[i] = arithmetic->b_matrix_.shape_[i];
+      arithmetic->c_matrix_.batch_post_sum_[i] = arithmetic->c_matrix_.shape_[i];
+    } else {
+      arithmetic->a_matrix_.batch_post_sum_[i] =
+        arithmetic->a_matrix_.shape_[i] * arithmetic->a_matrix_.batch_post_sum_[i + 1];
+      arithmetic->b_matrix_.batch_post_sum_[i] =
+        arithmetic->b_matrix_.shape_[i] * arithmetic->b_matrix_.batch_post_sum_[i + 1];
+      arithmetic->c_matrix_.batch_post_sum_[i] =
+        arithmetic->c_matrix_.shape_[i] * arithmetic->c_matrix_.batch_post_sum_[i + 1];
+    }
+  }
+
+  arithmetic->scalar_opt_ = false;
+  if (arithmetic->a_matrix_.inner_size_ == 1) {
+    arithmetic->in_elements_num0_ = 1;
+    arithmetic->scalar_opt_ = true;
+  }
+  if (arithmetic->b_matrix_.inner_size_ == 1) {
+    arithmetic->in_elements_num1_ = 1;
+    arithmetic->scalar_opt_ = true;
+  }
+  return NNACL_OK;
+}
+
+int ArithmeticChooseThreadCuttingStrategy(ArithmeticStruct *arithmetic) {
+  int total_num =
NNACLGetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]);
+  arithmetic->base_.thread_nr_ =
+    arithmetic->base_.UpdateThread(TC_TYPE(arithmetic->primitive_type_, arithmetic->functions_.activation_type_), 1, 1,
+                                   total_num, arithmetic->base_.thread_nr_);
+
+  int64_t block_size = UP_DIV(total_num, arithmetic->base_.thread_nr_);
+  int64_t split_point = 0;
+  while (split_point < total_num) {
+    int64_t start = split_point;
+    int64_t end = start + block_size;
+    if (end > total_num) {
+      end = total_num;
+    }
+    ArithmeticBlockBoundaryInfo block_boundary_info;
+    block_boundary_info.size_begin_ = start % arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.size_end_ = end % arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.batch_begin_ = start / arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.batch_end_ = end / arithmetic->c_matrix_.inner_size_;
+    block_boundary_info.init_offset_ = false;
+
+    int max_offset_size = block_boundary_info.batch_end_ - block_boundary_info.batch_begin_ + TWO_TENSOR;
+    block_boundary_info.a_offset_ =
+      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.a_offset_);
+    block_boundary_info.b_offset_ =
+      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.b_offset_);
+
+    arithmetic->block_boundary_infos_[arithmetic->block_boundary_infos_size_++] = block_boundary_info;
+    split_point = end;
+  }
+
+  arithmetic->base_.thread_nr_ = arithmetic->block_boundary_infos_size_;
+  return NNACL_OK;
+}
+
+int ArithmeticResize(struct KernelBase *self) {
+  ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
+
+  ArithmeticRelease(&arithmetic->base_);
+
+  NNACL_CHECK_TRUE_RET(arithmetic->in_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
+  NNACL_CHECK_TRUE_RET(arithmetic->out_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
+  arithmetic->in_elements_num0_ = NNACLGetElementNum(self->in_[FIRST_INPUT]);
+  arithmetic->in_elements_num1_ = NNACLGetElementNum(self->in_[SECOND_INPUT]);
+  arithmetic->out_elements_num_ = NNACLGetElementNum(self->out_[OUTPUT_INDEX]);
+
+  int ret = ResetArithmeticStatus(arithmetic);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ret = ArithmeticBroadCastConstTensor(arithmetic);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ret = ArithmeticComputeOfflineInfo(arithmetic);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  return ArithmeticChooseThreadCuttingStrategy(arithmetic);
+}
+
+int ArithmeticPrepare(struct KernelBase *self) {
+  ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
+
+  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
+
+  NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ < kNumberTypeBegin, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ < kNumberTypeBegin, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
+
+  if (self->param_->quant_type_ != Quant_None) {
+    return NNACL_ERR;
+  }
+
+  arithmetic->primitive_type_ = self->param_->type_;
+  if (self->param_->type_ == PrimType_Eltwise) {
+    switch (((ArithmeticParameter *)(self->param_))->eltwise_mode_) {
+      case Eltwise_PROD:
+        arithmetic->primitive_type_ =
PrimType_MulFusion; + break; + case Eltwise_SUM: + arithmetic->primitive_type_ = PrimType_AddFusion; + break; + case Eltwise_MAXIMUM: + arithmetic->primitive_type_ = PrimType_Maximum; + break; + default: + return NNACL_ELTWISE_INVALID_MOD; + } + } + arithmetic->init_function_(self); + + arithmetic->a_matrix_.is_const_ = NNACLIsConst(self->in_[FIRST_INPUT]); + arithmetic->b_matrix_.is_const_ = NNACLIsConst(self->in_[SECOND_INPUT]); + return NNACL_OK; +} + +int ArithmeticCompute(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != self->in_[SECOND_INPUT]->data_type_, + NNACL_ARITHMETIC_DATA_TYPE_UNMATCH); + + if (self->train_session_) { + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + } + + if (false == arithmetic->a_matrix_.is_valid_) { + arithmetic->a_matrix_.data_ = self->in_[FIRST_INPUT]->data_; + } + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_); + + if (!arithmetic->b_matrix_.is_const_) { + int ret = CheckDivDataInvalid(arithmetic); + if (ret != NNACL_OK) { + return ret; + } + } + if (false == arithmetic->b_matrix_.is_valid_) { + arithmetic->b_matrix_.data_ = self->in_[SECOND_INPUT]->data_; + } + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_); + + arithmetic->c_matrix_.data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.data_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticRun, self, self->thread_nr_); +} + +KernelBase *CreateArithmetic(OpParameter *param, int data_type) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)malloc(sizeof(ArithmeticStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic); + memset(arithmetic, 0, sizeof(ArithmeticStruct)); + arithmetic->in_data_size_ = DataTypeCSize(data_type); + arithmetic->out_data_size_ = DataTypeCSize(data_type); + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->tile_function_ = TileOneDimensionFp32; + arithmetic->init_function_ = InitArithmeticRunFunction; + arithmetic->execute_ = ArithmeticDoExecute; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticResize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompute; + return (KernelBase *)arithmetic; +} + +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat32, 
CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeInt32, CreateArithmetic) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.h new file mode 100644 index 00000000..7d261b40 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic.h @@ -0,0 +1,97 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_KERNEL_ARITHMETIC_H_ +#define NNACL_KERNEL_ARITHMETIC_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/arithmetic_parameter.h" + +typedef struct ArithmeticFuncions { + int primitive_type_; + int activation_type_; + int (*compute_f32_)(const float *in1, const float *in2, float *out, int ele); + int (*compute_int_)(const int *in1, const int *in2, int *out, int ele); + int (*compute_bool_)(const bool *in1, const bool *in2, bool *out, int ele); + int (*optimzie_f32_)(const float *in1, const float *in2, float *out, int ele, bool scalar); + int (*optimzie_int_)(const int *in1, const int *in2, int *out, int ele, bool scalar); + int (*optimzie_bool_)(const bool *in1, const bool *in2, bool *out, int ele, bool scalar); +} ArithmeticFuncions; + +typedef struct ArithmeticMatrixInfo { + bool is_const_; + bool is_valid_; + void *data_; + int64_t inner_size_; + int shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int shape_size_; + int *batch_post_sum_; /* shape size + 1 */ +} ArithmeticMatrixInfo; + +typedef struct ArithmeticBlockBoundaryInfo { + int batch_begin_; + int batch_end_; + int size_begin_; // start-offset under the begin batch + int size_end_; // end-num under the ending batch + int *a_offset_; + int *b_offset_; + bool init_offset_; +} ArithmeticBlockBoundaryInfo; + +typedef struct ArithmeticStruct { + KernelBase base_; + bool scalar_opt_; + int primitive_type_; + int ndim_; + int in_data_size_; + int out_data_size_; + int batch_tail_dim_; + + ArithmeticMatrixInfo a_matrix_; + ArithmeticMatrixInfo b_matrix_; + ArithmeticMatrixInfo c_matrix_; + ArithmeticFuncions functions_; + + void *broadcast_buffer_[TWO_TENSOR]; + int block_boundary_infos_size_; + ArithmeticBlockBoundaryInfo block_boundary_infos_[MAX_THREAD_NUM]; + + int in_shape0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_elements_num0_; + int in_shape1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_elements_num1_; + int out_shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_elements_num_; + int in_strides0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_strides1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_strides_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + + void (*tile_function_)(const void *inPtr, void *outPtr, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple); + int (*execute_)(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size); + void (*init_function_)(KernelBase *base); +} ArithmeticStruct; + +KernelBase *CreateArithmetic(OpParameter *param, int data_type); +int ArithmeticPrepare(struct KernelBase *self); +int ArithmeticRelease(struct KernelBase *self); +int ArithmeticCompute(struct KernelBase *self); +int ArithmeticResize(struct KernelBase *self); + +#endif // NNACL_KERNEL_ARITHMETIC_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.c new file mode 100644 index 00000000..5fee6647 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.c @@ -0,0 +1,166 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/arithmetic_compare.h" +#include "nnacl_c/kernel/arithmetic.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_compare_fp32.h" + +typedef struct ArithmeticCompareFuncions { + int primitive_type_; + int (*compute_f32_)(const float *input0, const float *input1, uint8_t *output, int element_size); + int (*compute_i32_)(const int *input0, const int *input1, uint8_t *output, int element_size); + int (*optimize_f32)(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); + int (*optimize_i32)(const int *input0, const int *input1, uint8_t *output, int element_size, bool first_scalar); + int (*compute_i64)(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size); + int (*optimize_i64)(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar); + int (*compute_bool)(const bool *input0, const bool *input1, uint8_t *output, int element_size); +} ArithmeticCompareFuncions; + +typedef struct ArithmeticCompareStruct { + ArithmeticStruct arithmetic_; + ArithmeticCompareFuncions functions_; +} ArithmeticCompareStruct; + +void InitArithmeticCompareRunFunction(KernelBase *self) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)self; + NNACL_CHECK_NULL_RETURN_VOID(arithmetic_compare); + + ArithmeticCompareFuncions fun_table[] = { + {PrimType_Equal, ElementEqualFp32, ElementEqualInt32, ElementOptEqualFp32, ElementOptEqualInt32, NULL, NULL, + ElementEqualBool}, + {PrimType_NotEqual, ElementNotEqualFp32, ElementNotEqualInt32, ElementOptNotEqualFp32, ElementOptNotEqualInt32, + ElementNotEqualInt64, ElementOptNotEqualInt64, NULL}, + {PrimType_Less, ElementLessFp32, ElementLessInt32, ElementOptLessFp32, ElementOptLessInt32, NULL, NULL, NULL}, + {PrimType_LessEqual, ElementLessEqualFp32, ElementLessEqualInt32, ElementOptLessEqualFp32, ElementOptLessEqualInt32, + NULL, NULL, NULL}, + {PrimType_Greater, ElementGreaterFp32, ElementGreaterInt32, ElementOptGreaterFp32, ElementOptGreaterInt32, NULL, + NULL, NULL}, + {PrimType_GreaterEqual, ElementGreaterEqualFp32, ElementGreaterEqualInt32, ElementOptGreaterEqualFp32, + ElementOptGreaterEqualInt32, NULL, NULL, NULL}}; + + size_t length = sizeof(fun_table) / sizeof(ArithmeticCompareFuncions); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == arithmetic_compare->arithmetic_.primitive_type_) { + arithmetic_compare->functions_ = fun_table[i]; + return; + } + } +} + +int ArithmeticCompareExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)base; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(input1); + + int data_type = base->in_[FIRST_INPUT]->data_type_; + bool first_scalar = arithmetic_compare->arithmetic_.in_elements_num0_ == 1; + + if (data_type == kNumberTypeFloat32) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_f32); + 
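/* Scalar fast path: one operand holds a single element; first_scalar tells
+       * the Opt kernel which side is the scalar. */
+      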
return arithmetic_compare->functions_.optimize_f32((const float *)input0, (const float *)input1, + (uint8_t *)output, size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_f32_); + return arithmetic_compare->functions_.compute_f32_((const float *)input0, (const float *)input1, + (uint8_t *)output, size); + } + } + + if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i32); + return arithmetic_compare->functions_.optimize_i32((const int *)input0, (const int *)input1, (uint8_t *)output, + size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i32_); + return arithmetic_compare->functions_.compute_i32_((const int *)input0, (const int *)input1, (uint8_t *)output, + size); + } + } + + if (data_type == kNumberTypeInt64) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i64); + return arithmetic_compare->functions_.optimize_i64((const int64_t *)input0, (const int64_t *)input1, + (uint8_t *)output, size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i64); + return arithmetic_compare->functions_.compute_i64((const int64_t *)input0, (const int64_t *)input1, + (uint8_t *)output, size); + } + } + if (data_type == kNumberTypeBool) { + if (!arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_bool); + return arithmetic_compare->functions_.compute_bool((const bool *)input0, (const bool *)input1, (uint8_t *)output, + size); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int ArithmeticCompareResize(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + arithmetic->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_); + return ArithmeticResize(self); +} + +KernelBase *CreateArithmeticCompare(OpParameter *param, int data_type) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)malloc(sizeof(ArithmeticCompareStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_compare); + memset(arithmetic_compare, 0, sizeof(ArithmeticCompareStruct)); + + ArithmeticStruct *arithmetic = (ArithmeticStruct *)arithmetic_compare; + arithmetic->in_data_size_ = DataTypeCSize(data_type); + arithmetic->out_data_size_ = DataTypeCSize(data_type); + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->tile_function_ = TileOneDimensionFp32; + arithmetic->init_function_ = InitArithmeticCompareRunFunction; + arithmetic->execute_ = ArithmeticCompareExecute; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticCompareResize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompute; + return (KernelBase *)arithmetic_compare; +} + +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Equal, 
kNumberTypeBool, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt64, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeInt32, CreateArithmeticCompare) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.h new file mode 100644 index 00000000..868196c0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_compare.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ARITHMETIC_COMPARE_H_ +#define NNACL_KERNEL_ARITHMETIC_COMPARE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateArithmeticCompare(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ARITHMETIC_COMPARE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.c new file mode 100644 index 00000000..6eaf2741 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.c @@ -0,0 +1,199 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/arithmetic_self.h" +#include "nnacl_c/fp32/arithmetic_self_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/arithmetic_self_fp16.h" +#endif + +void ArithmeticSelfGetArithmeticSelfFunction(ArithmeticSelfStruct *arithmetic_self, int primitive_type) { + ArithmeticSelfFunction type_func_table[] = { + {PrimType_Abs, ElementAbs, NULL, ElementAbsInt, NULL}, + {PrimType_Cos, ElementCos, NULL, NULL, NULL}, + {PrimType_Log, ElementLog, NULL, NULL, NULL}, + {PrimType_Log1p, ElementLog1p, NULL, NULL, NULL}, + {PrimType_Square, ElementSquare, NULL, NULL, NULL}, + {PrimType_Sqrt, ElementSqrt, NULL, NULL, NULL}, + {PrimType_Rsqrt, ElementRsqrt, NULL, NULL, NULL}, + {PrimType_Sin, ElementSin, NULL, NULL, NULL}, + {PrimType_LogicalNot, ElementLogicalNot, ElementLogicalNotBool, NULL, NULL}, + {PrimType_Floor, ElementFloor, NULL, NULL, NULL}, + {PrimType_Ceil, ElementCeil, NULL, NULL, NULL}, + {PrimType_Round, ElementRound, NULL, NULL, NULL}, + {PrimType_Neg, ElementNegative, NULL, ElementNegativeInt, NULL}, + {PrimType_Reciprocal, ElementReciprocal, NULL, NULL, NULL}, + {PrimType_Erf, ElementErf, NULL, NULL, NULL}, + {PrimType_IsFinite, NULL, NULL, NULL, ElementIsFinite}}; + for (size_t i = 0; i < sizeof(type_func_table) / sizeof(ArithmeticSelfFunction); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + arithmetic_self->function_ = type_func_table[i]; + return; + } + } +} + +void ArithmeticSelfGetArithmeticSelfF16Function(ArithmeticSelfStruct *arithmetic_self, int primitive_type) { +#ifdef ENABLE_FP16 + ArithmeticSelfF16Function type_func_table[] = {{PrimType_Abs, ElementAbsFp16}, + {PrimType_Cos, ElementCosFp16}, + {PrimType_Log, ElementLogFp16}, + {PrimType_Square, ElementSquareFp16}, + {PrimType_Sqrt, ElementSqrtFp16}, + {PrimType_Rsqrt, ElementRsqrtFp16}, + {PrimType_Sin, ElementSinFp16}, + {PrimType_LogicalNot, ElementLogicalNotFp16}, + {PrimType_Floor, ElementFloorFp16}, + {PrimType_Ceil, ElementCeilFp16}, + {PrimType_Round, ElementRoundFp16}, + {PrimType_Neg, ElementNegativeFp16}, + {PrimType_Reciprocal, ElementReciprocalFp16}, + {PrimType_Erf, ElementErfFp16}}; + for (size_t i = 0; i < sizeof(type_func_table) / sizeof(ArithmeticSelfF16Function); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + arithmetic_self->f16_function_ = type_func_table[i]; + return; + } + } +#endif + arithmetic_self->f16_function_.primitive_type_ = primitive_type; + return; +} + +int ArithmeticSelfExecute(ArithmeticSelfStruct *arithmetic_self, int task_id) { + int elements_num = NNACLGetElementNum(arithmetic_self->base_.in_[FIRST_INPUT]); + NNACL_CHECK_TRUE_RET(arithmetic_self->base_.thread_nr_, NNACL_ERR); + int stride = UP_DIV(elements_num, arithmetic_self->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, NNACL_ERR); + int offset = task_id * stride; + int count = NNACL_MIN(stride, elements_num - offset); + if (count <= 0) { + return NNACL_OK; + } + + void *in_data = arithmetic_self->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + void *out_data = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + int in_data_type = arithmetic_self->base_.in_[FIRST_INPUT]->data_type_; + int out_data_type = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_type_; + + if (in_data_type == kNumberTypeFloat32 && out_data_type == kNumberTypeBool) { + 
NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_float_bool_); + return arithmetic_self->function_.func_float_bool_((float *)in_data + offset, (bool *)out_data + offset, count); + } + + if (in_data_type == kNumberTypeFloat32) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_); + return arithmetic_self->function_.func_((float *)in_data + offset, (float *)out_data + offset, count); + } + + if (in_data_type == kNumberTypeBool) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_bool_); + return arithmetic_self->function_.func_bool_((bool *)in_data + offset, (bool *)out_data + offset, count); + } + + if (in_data_type == kNumberTypeInt32) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_int_); + return arithmetic_self->function_.func_int_((int32_t *)in_data + offset, (int32_t *)out_data + offset, count); + } + +#ifdef ENABLE_FP16 + if (in_data_type == kNumberTypeFloat16) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->f16_function_.func_); + return arithmetic_self->f16_function_.func_((float16_t *)in_data + offset, (float16_t *)out_data + offset, count); + } +#endif + return NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT; +} + +int ArithmeticSelfRun(void *cdata, int task_id, float l, float r) { + ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self); + return ArithmeticSelfExecute(arithmetic_self, task_id); +} + +int ArithmeticSelfResize(KernelBase *self) { + ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self); + self->thread_nr_ = arithmetic_self->base_.UpdateThread( + TC_PTYPE(arithmetic_self->op_type_), 1, 1, NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + return NNACL_OK; +} + +int ArithmeticSelfCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticSelfRun, self, self->thread_nr_); +} + +int ArithmeticSelfPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ != ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_[OUTPUT_INDEX]->category_ == ConstTensor, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_[OUTPUT_INDEX]->category_ == ConstScalar, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +KernelBase *CreateArithmeticSelf(OpParameter *param, int data_type) { + ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)malloc(sizeof(ArithmeticSelfStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_self); + ArithmeticSelfGetArithmeticSelfFunction(arithmetic_self, param->type_); + ArithmeticSelfGetArithmeticSelfF16Function(arithmetic_self, param->type_); + arithmetic_self->op_type_ = param->type_; + arithmetic_self->base_.Prepare = ArithmeticSelfPrepare; + arithmetic_self->base_.Resize = ArithmeticSelfResize; + arithmetic_self->base_.Release = DefaultRelease; + arithmetic_self->base_.Compute = ArithmeticSelfCompute; + return (KernelBase *)arithmetic_self; +} + +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeBool, CreateArithmeticSelf) + +REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeInt32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeInt32, CreateArithmeticSelf) + +REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat32, CreateArithmeticSelf) 
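+// Reviewer sketch (assumption, not part of the original kernel code): a kernel +// produced by these registrations is presumably driven through the KernelBase +// function pointers as Create -> Prepare -> Resize -> Compute -> Release, e.g. +//   KernelBase *k = CreateArithmeticSelf(param, kNumberTypeFloat32); +//   k->Prepare(k); k->Resize(k); k->Compute(k); k->Release(k); free(k); +// Each (PrimType, TypeId) pair in this list also needs a matching row in +// ArithmeticSelfGetArithmeticSelfFunction's table: CreateArithmeticSelf does +// not zero the struct, so an unmatched lookup leaves function_ uninitialized.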
+REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log1p, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_IsFinite, kNumberTypeFloat32, CreateArithmeticSelf) + +REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat16, CreateArithmeticSelf) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.h new file mode 100644 index 00000000..4b8bf8c0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/arithmetic_self.h @@ -0,0 +1,48 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_ARITHMETIC_SELF_H_ +#define NNACL_KERNEL_ARITHMETIC_SELF_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ArithmeticSelfFunction { + int primitive_type_; + int (*func_)(const float *input, float *output, const int element_size); + int (*func_bool_)(const bool *input, bool *output, const int element_size); + int (*func_int_)(const int *input, int *output, const int element_size); + int (*func_float_bool_)(const float *input, bool *output, const int element_size); +} ArithmeticSelfFunction; + +typedef struct ArithmeticSelfF16Function { + int primitive_type_; +#ifdef ENABLE_FP16 + int (*func_)(const float16_t *input, float16_t *output, int element_size); +#endif +} ArithmeticSelfF16Function; + +typedef struct ArithmeticSelfStruct { + KernelBase base_; + int op_type_; + ArithmeticSelfFunction function_; + ArithmeticSelfF16Function f16_function_; +} ArithmeticSelfStruct; + +KernelBase *CreateArithmeticSelf(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ARITHMETIC_SELF_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.c new file mode 100644 index 00000000..ae0d1519 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.c @@ -0,0 +1,134 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/batch_norm.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/batchnorm_parameter.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/batchnorm_fp16.h" +#endif + +int BatchNormFillParam(BatchNormStruct *batch_norm) { + TensorC *input_tensor = batch_norm->base_.in_[FIRST_INPUT]; + int in_channel = input_tensor->shape_[input_tensor->shape_size_ - 1]; + + TensorC *mean_tensor = batch_norm->base_.in_[SECOND_INPUT]; + int mean_channel = mean_tensor->shape_[mean_tensor->shape_size_ - 1]; + + TensorC *var_tensor = batch_norm->base_.in_[THIRD_INPUT]; + int var_channel = var_tensor->shape_[var_tensor->shape_size_ - 1]; + + if (in_channel != mean_channel || in_channel != var_channel) { + return NNACL_BATCH_NORM_CHANNEL_SHAPE_INVALID; + } + + batch_norm->channel_ = in_channel; + batch_norm->unit_ = 1; + for (size_t i = 0; i < input_tensor->shape_size_ - 1; i++) { + batch_norm->unit_ *= input_tensor->shape_[i]; + } + if (batch_norm->momentum_ < 0.0f) { + batch_norm->momentum_ = 0.0f; + } + return NNACL_OK; +} + +int BatchNormRun(void *cdata, int task_id, float l, float r) { + BatchNormStruct *bn = (BatchNormStruct *)cdata; + void *in_data = bn->base_.in_[FIRST_INPUT]->data_; + void *out_data = bn->base_.out_[OUTPUT_INDEX]->data_; + if (bn->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + BatchNormFp16((float16_t *)in_data, (float16_t *)bn->mean_, (float16_t *)bn->variance_, bn, task_id, + bn->base_.thread_nr_, (float16_t *)out_data); +#endif + } else { + BatchNormFp32((float *)in_data, (float *)bn->mean_, (float *)bn->variance_, bn, task_id, bn->base_.thread_nr_, + (float *)out_data); + } + return NNACL_OK; +} + +int BatchNormReSize(KernelBase *self) { + BatchNormStruct *batch_norm = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_norm); + + int ret = BatchNormFillParam(batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + (void)batch_norm->base_.Release(self); + + batch_norm->mean_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(self->in_[SECOND_INPUT])); + batch_norm->variance_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(self->in_[THIRD_INPUT])); + if (batch_norm->mean_ == NULL || batch_norm->variance_ == NULL) { + (void)batch_norm->base_.Release(self); + return NNACL_ERR; + } + + (void)memcpy(batch_norm->mean_, self->in_[SECOND_INPUT]->data_, NNACLGetSize(self->in_[SECOND_INPUT])); + (void)memcpy(batch_norm->variance_, self->in_[THIRD_INPUT]->data_, NNACLGetSize(self->in_[THIRD_INPUT])); + return NNACL_OK; +} + +int BatchNormRelease(KernelBase *self) { + BatchNormStruct *batch_norm = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_norm); + + if (batch_norm->mean_ != NULL) { + self->env_->Free(self->env_->allocator_, batch_norm->mean_); + batch_norm->mean_ = NULL; + } + if (batch_norm->variance_ != NULL) { + self->env_->Free(self->env_->allocator_, batch_norm->variance_); + batch_norm->variance_ = NULL; + } + return NNACL_OK; +} + +int BatchNormPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + BatchNormStruct *batch_norm = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_norm); + batch_norm->momentum_ = -1.0f; + batch_norm->epsilon_ = ((BatchNormParameter *)self->param_)->epsilon_; + return NNACL_OK; +} + +int BatchNormCompute(KernelBase *self) { + return 
self->env_->ParallelLaunch(self->env_->thread_pool_, BatchNormRun, self, self->thread_nr_); +} + +KernelBase *CreateBatchNorm(OpParameter *param, int data_type) { + BatchNormStruct *batch_norm = (BatchNormStruct *)malloc(sizeof(BatchNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(batch_norm); + memset(batch_norm, 0, sizeof(BatchNormStruct)); + batch_norm->data_type_ = data_type; + batch_norm->base_.Prepare = BatchNormPrepare; + batch_norm->base_.Resize = BatchNormReSize; + batch_norm->base_.Release = BatchNormRelease; + batch_norm->base_.Compute = BatchNormCompute; + return (KernelBase *)batch_norm; +} + +REG_KERNEL_CREATOR(PrimType_BatchNorm, kNumberTypeFloat16, CreateBatchNorm) +REG_KERNEL_CREATOR(PrimType_BatchNorm, kNumberTypeFloat32, CreateBatchNorm) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.h new file mode 100644 index 00000000..e1afa44a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_norm.h @@ -0,0 +1,38 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BATCH_NORM_H_ +#define NNACL_KERNEL_BATCH_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct BatchNormStruct { + KernelBase base_; + int data_type_; + void *mean_; + void *variance_; + float momentum_; + int unit_; + int channel_; + float epsilon_; +} BatchNormStruct; + +KernelBase *CreateBatchNorm(OpParameter *param, int data_type); +int BatchNormRelease(KernelBase *self); +int BatchNormFillParam(BatchNormStruct *batch_norm); + +#endif // NNACL_KERNEL_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.c new file mode 100644 index 00000000..f0953fbc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.c @@ -0,0 +1,114 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/batch_to_space.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/base/batch_to_space_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/batch_to_space_parameter.h" + +int BatchToSpaceProcessInput(BatchToSpaceStruct *batch_to_space) { + TensorC *block_shape = batch_to_space->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(block_shape); + NNACL_CHECK_NULL_RETURN_ERR(block_shape->data_); + TensorC *crop = batch_to_space->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(crop); + NNACL_CHECK_NULL_RETURN_ERR(crop->data_); + + if (NNACLGetElementNum(block_shape) < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { + return NNACL_BATCH_TO_SPACE_BLOCK_SHAPE_INVALID; + } + if (NNACLGetElementNum(crop) < COMM_SHAPE_SIZE) { + return NNACL_BATCH_TO_SPACE_CROP_INVALID; + } + + int32_t *block_shape_data = (int32_t *)block_shape->data_; + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + batch_to_space->block_shape_[i] = block_shape_data[i]; + } + + int32_t *crops_data = (int32_t *)crop->data_; + batch_to_space->no_crop_ = true; + for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { + batch_to_space->crops_[i] = crops_data[i]; + if (batch_to_space->crops_[i] != 0) { + batch_to_space->no_crop_ = false; + } + } + return NNACL_OK; +} + +int BatchToSpaceCompute(KernelBase *self) { + BatchToSpaceStruct *batch_to_space = (BatchToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_to_space); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + size_t data_size = DataTypeCSize(input->data_type_); + if (self->in_size_ == Num1) { + if (batch_to_space->no_crop_) { + BatchToSpaceNoCropForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, data_size); + } else { + BatchToSpaceForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, batch_to_space->crops_, data_size); + } + } + + if (self->in_size_ == Num3) { + int ret = BatchToSpaceProcessInput(batch_to_space); + if (ret != NNACL_OK) { + return ret; + } + if (batch_to_space->no_crop_) { + BatchToSpaceNoCropForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, data_size); + } else { + BatchToSpaceForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, batch_to_space->crops_, data_size); + } + } + return NNACL_OK; +} + +int BatchToSpaceResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_TRUE_RET(self->in_[FIRST_INPUT]->shape_size_ == COMM_SHAPE_SIZE, NNACL_ERR); + return NNACL_OK; +} + +KernelBase *CreateBatchToSpace(OpParameter *param, int data_type) { + BatchToSpaceStruct *batch_to_space = (BatchToSpaceStruct *)malloc(sizeof(BatchToSpaceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(batch_to_space); + memset(batch_to_space, 0, sizeof(BatchToSpaceStruct)); + BatchToSpaceParameter *bts_param = (BatchToSpaceParameter *)param; + memcpy(batch_to_space->crops_, bts_param->crops_, sizeof(int32_t) * COMM_SHAPE_SIZE); + memcpy(batch_to_space->block_shape_, bts_param->block_shape_, sizeof(int32_t) * BATCH_TO_SPACE_BLOCK_SHAPE_SIZE); + batch_to_space->base_.Prepare = DefaultPrepare1In1Out; + batch_to_space->base_.Resize = BatchToSpaceResize; + 
batch_to_space->base_.Release = DefaultRelease; + batch_to_space->base_.Compute = BatchToSpaceCompute; + return (KernelBase *)batch_to_space; +} + +REG_KERNEL_CREATOR(PrimType_BatchToSpace, kNumberTypeFloat16, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpace, kNumberTypeFloat32, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpaceND, kNumberTypeFloat16, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpaceND, kNumberTypeFloat32, CreateBatchToSpace) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.h new file mode 100644 index 00000000..3e75a4a6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/batch_to_space.h @@ -0,0 +1,33 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BATCH_TO_SPACE_H_ +#define NNACL_KERNEL_BATCH_TO_SPACE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/batch_to_space_parameter.h" + +typedef struct BatchToSpaceStruct { + KernelBase base_; + bool no_crop_; + int32_t crops_[COMM_SHAPE_SIZE]; + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE]; +} BatchToSpaceStruct; + +KernelBase *CreateBatchToSpace(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_BATCH_TO_SPACE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.c new file mode 100644 index 00000000..a565a8c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.c @@ -0,0 +1,131 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/biasadd.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/bias_add.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +#define BIAS_ADD_PER_UNIT_LOAD_NUM 2 +#define BIAS_ADD_PER_UNIT_STORE_NUM 1 +#define SPLIT_POINTS_SIZE 32 + +typedef struct BiasAddStruct { + KernelBase base_; + int64_t inner_num_; + int64_t outer_num_; + int64_t total_num_; + bool batch_priority_; + int64_t split_points_[SPLIT_POINTS_SIZE]; + int split_pionts_size_; +} BiasAddStruct; + +int ChooseBiasThreadCuttingStrategy(KernelBase *self) { + BiasAddStruct *bias_add = (BiasAddStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(bias_add); + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_BiasAdd), BIAS_ADD_PER_UNIT_LOAD_NUM, + BIAS_ADD_PER_UNIT_STORE_NUM, bias_add->total_num_, self->thread_nr_); + if (self->thread_nr_ > SPLIT_POINTS_SIZE) { + self->thread_nr_ = SPLIT_POINTS_SIZE; + } + + bias_add->split_pionts_size_ = 0; + int64_t block_size = 1; + block_size = bias_add->total_num_ / self->thread_nr_; + int64_t remain_data = bias_add->total_num_ - block_size * self->thread_nr_; + int64_t split_point = 0; + while (split_point < bias_add->total_num_) { + bias_add->split_points_[bias_add->split_pionts_size_++] = split_point; + split_point += block_size; + if (remain_data > 0) { + ++split_point; + --remain_data; + } + } + self->thread_nr_ = bias_add->split_pionts_size_; + if (bias_add->inner_num_ >= C64NUM && block_size / bias_add->inner_num_ >= C6NUM) { + bias_add->batch_priority_ = true; + } else { + bias_add->batch_priority_ = false; + } + return NNACL_OK; +} + +int BiasRun(void *cdata, int task_id, float l, float r) { + BiasAddStruct *bias_add = (BiasAddStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(bias_add); + + float *input = (float *)(bias_add->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(input); + float *bias = (float *)(bias_add->base_.in_[SECOND_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(bias); + float *output = (float *)(bias_add->base_.out_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(output); + + int64_t block_start = bias_add->split_points_[task_id]; + int64_t block_end = bias_add->total_num_; + if ((task_id + 1) < bias_add->split_pionts_size_) { + block_end = bias_add->split_points_[task_id + 1]; + } + BiasAddOpt(input, bias, output, block_start, block_end, bias_add->inner_num_, bias_add->batch_priority_); + return NNACL_OK; +} + +int BiasAddResize(struct KernelBase *self) { + BiasAddStruct *bias_add = (BiasAddStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(bias_add); + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + TensorC *add_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_FALSE(in_tensor->shape_size_ == 0, NNACL_ERR); + NNACL_CHECK_FALSE(add_tensor->shape_size_ == 0, NNACL_ERR); + NNACL_CHECK_FALSE(in_tensor->shape_size_ < add_tensor->shape_size_, NNACL_ERR); + + size_t dim_offset = in_tensor->shape_size_ - add_tensor->shape_size_; + bias_add->inner_num_ = 1; + for (size_t i = 0; i < add_tensor->shape_size_; ++i) { + NNACL_CHECK_FALSE(in_tensor->shape_[i + dim_offset] != add_tensor->shape_[i], NNACL_BIAS_ADD_SHAPE_NOT_MATCH); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(in_tensor->shape_[i], bias_add->inner_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW); + bias_add->inner_num_ *= add_tensor->shape_[i]; + } + + bias_add->outer_num_ = 1; + for (size_t i = 0; i < dim_offset; ++i) { + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(in_tensor->shape_[i], bias_add->outer_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW); + bias_add->outer_num_ *= 
in_tensor->shape_[i]; + } + + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(bias_add->inner_num_, bias_add->outer_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW); + bias_add->total_num_ = bias_add->inner_num_ * bias_add->outer_num_; + return ChooseBiasThreadCuttingStrategy(self); +} + +int BiasAddCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, BiasRun, self, self->thread_nr_); +} + +KernelBase *CreateBiasAdd(OpParameter *param, int data_type) { + BiasAddStruct *bias_add = (BiasAddStruct *)malloc(sizeof(BiasAddStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(bias_add); + bias_add->base_.Prepare = DefaultPrepare2In1Out; + bias_add->base_.Resize = BiasAddResize; + bias_add->base_.Release = DefaultRelease; + bias_add->base_.Compute = BiasAddCompute; + return (KernelBase *)bias_add; +} + +REG_KERNEL_CREATOR(PrimType_BiasAdd, kNumberTypeFloat32, CreateBiasAdd) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.h new file mode 100644 index 00000000..1a8577c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/biasadd.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BIASADD_H_ +#define NNACL_KERNEL_BIASADD_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateBiasAdd(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_BIASADD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.c new file mode 100644 index 00000000..cfbfaf85 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.c @@ -0,0 +1,209 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/cast.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/cast_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" + +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/cast_fp16.h" +#endif + +int CastToFp32(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; + float *output_data = (float *)output->data_; + switch (input_data_type) { + case kNumberTypeBool: + BoolToFloat32((const bool *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFloat32((const uint8_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFloat32((const int32_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeFloat16: +#ifdef ENABLE_FP16 + Fp16ToFloat32((const float16_t *)(input->data_) + offset, output_data + offset, data_num); +#else + Fp16ToFloat32((const uint16_t *)(input->data_) + offset, output_data + offset, data_num); +#endif + break; + case kNumberTypeInt64: + Int64ToFloat32((const int64_t *)(input->data_) + offset, output_data + offset, data_num); + break; + default: + return NNACL_ERR; + } + return NNACL_OK; +} + +int CastToFp16(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; +#ifdef ENABLE_FP16 + float16_t *output_data = (float16_t *)output->data_; + switch (input_data_type) { + case kNumberTypeFloat32: + Float32ToFp16((const float *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt64: + Int64ToFp16((const int64_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFp16((const int32_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeBool: + BoolToFp16((const bool *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFp16((const uint8_t *)(input->data_) + offset, output_data + offset, data_num); + break; + default: + return NNACL_ERR; + } +#else + if (input_data_type == kNumberTypeFloat32) { + Float32ToFp16((const float *)(input->data_) + offset, (uint16_t *)(output->data_) + offset, data_num); + } else { + return NNACL_ERR; + } +#endif + return NNACL_OK; +} + +int CastToOthers(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; + int output_data_type = output->data_type_; + void *output_data = output->data_; + if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { + Float32ToInt64((const float *)(input->data_) + offset, (int64_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { + Float32ToInt32((const float *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { + Int32ToInt64((const int32_t *)(input->data_) + offset, (int64_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeInt32) { + Int64ToInt32((const int64_t *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt16) { + Float32ToInt16((const float *)(input->data_) + offset, (int16_t *)(output_data) + offset, data_num); + 
} else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { + BoolToInt32((const bool *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeBool) { + Float32ToBool((const float *)(input->data_) + offset, (bool *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeUInt8) { + Float32ToUint8((const float *)(input->data_) + offset, (uint8_t *)(output_data) + offset, data_num); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +int CastLaunch(void *cdata, int task_id, float l, float r) { + CastStruct *cast = (CastStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cast); + + NNACL_CHECK_FALSE(cast->base_.in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(cast->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + TensorC *in = cast->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + NNACL_CHECK_NULL_RETURN_ERR(in->data_); + TensorC *out = cast->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out); + NNACL_CHECK_NULL_RETURN_ERR(out->data_); + + int stride = cast->stride_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, NNACL_ERR); + + int data_num = MSMIN(stride, cast->data_num_ - task_id * stride); + if (data_num <= 0) { + return NNACL_OK; + } + + int offset = task_id * stride; + int input_data_type = in->data_type_; + int output_data_type = out->data_type_; + if (input_data_type == output_data_type) { + size_t datalen = DataTypeCSize((TypeIdC)input_data_type); + memcpy((int8_t *)(out->data_) + offset * datalen, (int8_t *)(in->data_) + offset * datalen, data_num * datalen); + return NNACL_OK; + } + + if (output_data_type == kNumberTypeFloat32) { + return CastToFp32(in, out, offset, data_num); + } else if (output_data_type == kNumberTypeFloat16) { + return CastToFp16(in, out, offset, data_num); + } else { + return CastToOthers(in, out, offset, data_num); + } + return NNACL_OK; +} + +int cast_prepare(struct KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +// Kernel resize input shape +int cast_resize(struct KernelBase *self) { + CastStruct *cast = (CastStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(cast); + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + int data_num = NNACLGetElementNum(in_tensor); + if (data_num == 0) { + return NNACL_OK; + } + + cast->data_num_ = data_num; + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + // update thread num + cast->base_.thread_nr_ = cast->base_.UpdateThread( + TC_PTYPE(PrimType_Cast), 1, 1, NNACLGetElementNum(cast->base_.out_[FIRST_INPUT]), cast->base_.thread_nr_); + cast->stride_ = UP_DIV(data_num, cast->base_.thread_nr_); + return NNACL_OK; +} + +int cast_release(struct KernelBase *self) { return NNACL_OK; } + +// Cast Op Compute +int cast_compute(struct KernelBase *self) { + CastStruct *cast = (CastStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(cast); + if (cast->data_num_ == 0) { + return NNACL_OK; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, CastLaunch, self, self->thread_nr_); +} + +KernelBase *CreateCast(OpParameter *param, int data_type) { + CastStruct *cast = (CastStruct 
*)malloc(sizeof(CastStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(cast); + memset(cast, 0, sizeof(CastStruct)); + cast->base_.Prepare = cast_prepare; + cast->base_.Resize = cast_resize; + cast->base_.Release = cast_release; + cast->base_.Compute = cast_compute; + cast->stride_ = 0; + cast->data_num_ = 0; + return (KernelBase *)cast; +} + +// todo register kernel diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.h new file mode 100644 index 00000000..1312c4aa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/cast.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CAST_H_ +#define NNACL_KERNEL_CAST_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct CastStruct { + KernelBase base_; + int stride_; + int data_num_; +} CastStruct; + +KernelBase *CreateCast(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CAST_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.c new file mode 100644 index 00000000..1d8ef8e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.c @@ -0,0 +1,123 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/clip.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/clip_parameter.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int GetClipMinMaxValue(TensorC *tensor, float *data) { + NNACL_CHECK_NULL_RETURN_ERR(tensor); + switch (tensor->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + *data = *((float *)tensor->data_); + break; + case kNumberTypeInt: + case kNumberTypeInt32: + *data = *((int *)tensor->data_); + break; + default: + return NNACL_CLIP_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int ClipResize(struct KernelBase *self) { + ClipStruct *clip = (ClipStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(clip); + clip->base_.thread_nr_ = clip->base_.UpdateThread( + TC_PTYPE(PrimType_Clip), 1, 1, NNACLGetElementNum(clip->base_.out_[FIRST_INPUT]), clip->base_.thread_nr_); + + clip->length_ = NNACLGetElementNum(clip->base_.in_[FIRST_INPUT]); + clip->stride_ = UP_DIV(clip->length_, clip->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(clip->stride_, clip->base_.thread_nr_, NNACL_ERR); + return NNACL_OK; +} + +int ClipImpl(void *cdata, int task_id, float l, float r) { + ClipStruct *clip = (ClipStruct *)cdata; + void *in = clip->base_.in_[FIRST_INPUT]->data_; + void *out = clip->base_.out_[FIRST_INPUT]->data_; + + int stride = clip->stride_ * task_id; + int count = NNACL_MIN(clip->stride_, clip->length_ - stride); + if (count <= 0) { + return NNACL_OK; + } + + switch (clip->base_.in_[FIRST_INPUT]->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: { + return Fp32Clip((float *)in + stride, count, (float *)out + stride, clip->min_val_, clip->max_val_); + } break; + case kNumberTypeInt: + case kNumberTypeInt32: { + return Int32Clip((int *)in + stride, count, (int *)out + stride, (int)clip->min_val_, (int)clip->max_val_); + } break; + default: + return NNACL_CLIP_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int ClipCompute(struct KernelBase *self) { + ClipStruct *clip = (ClipStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(clip); + ClipParameter *param = (ClipParameter *)clip->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + clip->min_val_ = param->min_val_; + clip->max_val_ = param->max_val_; + + int ret = NNACL_OK; + if (clip->base_.in_size_ > ONE_TENSOR) { + TensorC *min_tensor = clip->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(min_tensor); + NNACL_CHECK_NULL_RETURN_ERR(min_tensor->data_); + ret = GetClipMinMaxValue(min_tensor, &(clip->min_val_)); + if (ret != NNACL_OK) { + // return early so a min-tensor failure is not masked by a later max read + return ret; + } + } + if (clip->base_.in_size_ > TWO_TENSOR) { + TensorC *max_tensor = clip->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(max_tensor); + NNACL_CHECK_NULL_RETURN_ERR(max_tensor->data_); + ret = GetClipMinMaxValue(max_tensor, &(clip->max_val_)); + } + if (ret != NNACL_OK) { + return ret; + } + if (clip->min_val_ >= clip->max_val_) { + return NNACL_CLIP_MINMAX_VALUE_INVALID; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ClipImpl, clip, clip->base_.thread_nr_); +} + +KernelBase *CreateClip(OpParameter *param, int data_type) { + ClipStruct *clip = (ClipStruct *)malloc(sizeof(ClipStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(clip); + clip->base_.Prepare = DefaultPrepare1In1Out; + clip->base_.Resize = ClipResize; + clip->base_.Release = DefaultRelease; + clip->base_.Compute = ClipCompute; + return (KernelBase *)clip; +} + +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeFloat, CreateClip) +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeFloat32, CreateClip) 
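+// Reviewer note (hedged): for the integer registrations below, ClipImpl reuses +// the float min_val_/max_val_ fields and truncates them per call via (int) +// casts, so fractional bounds round toward zero; e.g. min=0.5f, max=2.7f clips +// the int32 input {1, 3} to {1, 2} under this reading.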
+REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeInt, CreateClip) +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeInt32, CreateClip) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.h new file mode 100644 index 00000000..23f91ee2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/clip.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CLIP_H_ +#define NNACL_KERNEL_CLIP_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ClipStruct { + KernelBase base_; + float min_val_; + float max_val_; + int length_; + int stride_; +} ClipStruct; + +KernelBase *CreateClip(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CLIP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.c new file mode 100644 index 00000000..4f382e86 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.c @@ -0,0 +1,287 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/concat.h" +#include "nnacl_c/concat_parameter.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" + +#define kConcatMinCostPerThread 16384 + +int DoConcat(ConcatStruct *concat, int task_id) { + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id > concat->block_size_, NNACL_ERR); + + int all_bytes = NNACLGetSize(concat->base_.out_[FIRST_INPUT]); + int64_t start = concat->block_splits_[task_id]; + int64_t end = task_id < (concat->block_size_ - 1) ? 
concat->block_splits_[task_id + 1] : all_bytes; + int64_t start_row = start / concat->inner_sizes_[concat->base_.in_size_]; + int64_t end_row = end / concat->inner_sizes_[concat->base_.in_size_]; + + size_t src_buf_size = concat->base_.in_size_ * sizeof(uint8_t *); + NNACL_CHECK_MALLOC_SIZE(src_buf_size); + uint8_t **src = (uint8_t **)concat->base_.env_->Alloc(concat->base_.env_->allocator_, src_buf_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(src); + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + if (concat->is_with_data_[i]) { + src[i] = concat->inputs_ptr_[i] + start_row * concat->inner_sizes_[i]; + } + } + uint8_t *out = concat->output_ + start; + + int input_index = concat->block_boundary_infos_[task_id].begin_input_; + int end_index = concat->block_boundary_infos_[task_id].end_input_; + if (start_row == end_row) { + if (input_index == end_index) { + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, + concat->block_boundary_infos_[task_id].end_point_ - concat->block_boundary_infos_[task_id].begin_point_); + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; + } + int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_; + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size); + out += size; + ++input_index; + for (; input_index < end_index; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + out += concat->inner_sizes_[input_index]; + } + memcpy(out, src[input_index], concat->block_boundary_infos_[task_id].end_point_); + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; + } + for (int i = 0; i < input_index; ++i) { + src[i] += concat->inner_sizes_[i]; + } + int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_; + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size); + src[input_index] += concat->inner_sizes_[input_index]; + out += size; + ++input_index; + for (; input_index < concat->base_.in_size_; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + src[input_index] += concat->inner_sizes_[input_index]; + out += concat->inner_sizes_[input_index]; + } + ++start_row; + for (; start_row < end_row; ++start_row) { + for (input_index = 0; input_index < concat->base_.in_size_; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + src[input_index] += concat->inner_sizes_[input_index]; + out += concat->inner_sizes_[input_index]; + } + } + for (input_index = 0; input_index < end_index; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + out += concat->inner_sizes_[input_index]; + } + memcpy(out, src[end_index], concat->block_boundary_infos_[task_id].end_point_); + + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; +} + +int ConcatRun(void *cdata, int task_id, float l, float r) { + ConcatStruct *concat = (ConcatStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(concat); + return DoConcat(concat, task_id); +} + +int InitConcatDynamicStatus(ConcatStruct *concat) { + ConcatParameter *param = (ConcatParameter *)concat->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + size_t i = 0; + int64_t output_inner_size = 0; + for (; i < concat->base_.in_size_; i++) { + TensorC *t = concat->base_.in_[i]; + NNACL_CHECK_FALSE(param->axis_ >= t->shape_size_, 
NNACL_CONCAT_AXIS_INVALID); + int64_t outer_size = 1; + for (int j = 0; j < param->axis_; ++j) { + outer_size *= t->shape_[j]; + } + int inner_size = DataTypeCSize(concat->data_type_); + NNACL_CHECK_TRUE_RET(inner_size > 0, NNACL_UNSUPPORTED_DATA_TYPE); + + for (int j = param->axis_; j < t->shape_size_; ++j) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(inner_size, t->shape_[j], NNACL_CONCAT_SHAPE_INVALID); + inner_size *= t->shape_[j]; + } + if (i == 0) { + concat->outer_size_ = outer_size; + } else { + NNACL_CHECK_TRUE_RET(concat->outer_size_ == outer_size, NNACL_CONCAT_SHAPE_INVALID); + } + if (inner_size == 0) { + concat->is_with_data_[i] = false; + concat->inner_sizes_[i] = inner_size; + continue; + } + concat->is_with_data_[i] = true; + concat->inner_sizes_[i] = inner_size; + output_inner_size += inner_size; + } + concat->inner_sizes_[i] = output_inner_size; + return NNACL_OK; +} + +void ComputeConcatUnitBoundary(ConcatStruct *concat, int64_t *pre_sum, int offset, int *input, int64_t *point) { + size_t index = 0; + for (; index < concat->base_.in_size_; ++index) { + if (offset < pre_sum[index]) { + break; + } + } + *input = index; + *point = concat->inner_sizes_[index] - (pre_sum[index] - offset); +} + +int ChooseConcatThreadCuttingStrategy(ConcatStruct *concat) { + NNACL_CHECK_TRUE_RET(concat->base_.thread_nr_ > 0, NNACL_ERR); + + int all_bytes = NNACLGetSize(concat->base_.out_[FIRST_INPUT]); + int64_t thread_count = MSMAX(1, MSMIN(all_bytes / kConcatMinCostPerThread, concat->base_.thread_nr_)); + + NNACL_CHECK_ZERO_RETURN_ERR(thread_count); + int64_t block_size = all_bytes / thread_count; + int64_t remain_byte = all_bytes - block_size * thread_count; + int64_t *pre_sum = + (int64_t *)concat->base_.env_->Alloc(concat->base_.env_->allocator_, concat->base_.in_size_ * sizeof(int64_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pre_sum); + int64_t init_sum = 0; + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + init_sum += concat->inner_sizes_[i]; + pre_sum[i] = init_sum; + } + + concat->block_size_ = 0; + + int64_t block_spilt = 0; + while (block_spilt < all_bytes) { + concat->block_splits_[concat->block_size_] = block_spilt; + block_spilt += block_size; + if (remain_byte > 0) { + ++block_spilt; + --remain_byte; + } + int64_t start = concat->block_splits_[concat->block_size_]; + int64_t end = block_spilt > all_bytes ? all_bytes : block_spilt; + int64_t start_offset = start - DOWN_ROUND(start, concat->inner_sizes_[concat->base_.in_size_]); + int64_t end_offset = end - DOWN_ROUND(end, concat->inner_sizes_[concat->base_.in_size_]); + ConcatBlockBoundaryInfo block_boundary_info; + ComputeConcatUnitBoundary(concat, pre_sum, start_offset, &block_boundary_info.begin_input_, + &block_boundary_info.begin_point_); + ComputeConcatUnitBoundary(concat, pre_sum, end_offset, &block_boundary_info.end_input_, + &block_boundary_info.end_point_); + concat->block_boundary_infos_[concat->block_size_] = block_boundary_info; + concat->block_size_++; + } + + concat->base_.thread_nr_ = concat->block_size_; + concat->base_.env_->Free(concat->base_.env_->allocator_, pre_sum); + return NNACL_OK; +} + +int ConcatResize(KernelBase *self) { + ConcatStruct *concat = (ConcatStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat); + ConcatParameter *param = (ConcatParameter *)concat->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + param->axis_ = param->axis_ >= 0 ? 
+  NNACL_CHECK_FALSE(param->axis_ < 0, NNACL_CONCAT_AXIS_INVALID);
+  NNACL_CHECK_FALSE(param->axis_ >= self->in_[FIRST_INPUT]->shape_size_, NNACL_CONCAT_AXIS_INVALID);
+
+  int ret = InitConcatDynamicStatus(concat);
+  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
+
+  if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) {
+    return NNACL_OK;
+  }
+
+  return ChooseConcatThreadCuttingStrategy(concat);
+}
+
+int ConcatPrepare(KernelBase *self) {
+  ConcatStruct *concat = (ConcatStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(concat);
+
+  concat->inputs_ptr_ = self->env_->Alloc(self->env_->allocator_, self->in_size_ * sizeof(uint8_t *));
+  NNACL_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_);
+  concat->is_with_data_ = self->env_->Alloc(self->env_->allocator_, self->in_size_ * sizeof(bool));
+  NNACL_CHECK_NULL_RETURN_ERR(concat->is_with_data_);
+  concat->inner_sizes_ =
+    self->env_->Alloc(self->env_->allocator_, (self->in_size_ + self->out_size_) * sizeof(int64_t));
+  NNACL_CHECK_NULL_RETURN_ERR(concat->inner_sizes_);
+
+  return NNACL_OK;
+}
+
+int ConcatRelease(KernelBase *self) {
+  ConcatStruct *concat = (ConcatStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(concat);
+  if (concat->inputs_ptr_ != NULL) {
+    self->env_->Free(self->env_->allocator_, concat->inputs_ptr_);
+  }
+  if (concat->is_with_data_ != NULL) {
+    self->env_->Free(self->env_->allocator_, concat->is_with_data_);
+  }
+  if (concat->inner_sizes_ != NULL) {
+    self->env_->Free(self->env_->allocator_, concat->inner_sizes_);
+  }
+  return NNACL_OK;
+}
+
+int ConcatCompute(KernelBase *self) {
+  ConcatStruct *concat = (ConcatStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(concat);
+  if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) {
+    return NNACL_OK;
+  }
+
+  for (size_t i = 0; i < self->in_size_; ++i) {
+    if (!concat->is_with_data_[i]) {
+      continue;
+    }
+    NNACL_CHECK_NULL_RETURN_ERR(self->in_[i]->data_);
+    concat->inputs_ptr_[i] = self->in_[i]->data_;
+  }
+
+  concat->output_ = self->out_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(concat->output_);
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatRun, self, self->thread_nr_);
+}
+
+KernelBase *CreateConcat(OpParameter *param, int data_type) {
+  ConcatStruct *concat = (ConcatStruct *)malloc(sizeof(ConcatStruct));
+  NNACL_CHECK_NULL_RETURN_NULL(concat);
+  memset(concat, 0, sizeof(ConcatStruct));
+  concat->data_type_ = kNumberTypeFloat32;
+  concat->inner_sizes_ = NULL;
+  concat->inputs_ptr_ = NULL;
+  concat->is_with_data_ = NULL;
+  concat->base_.Prepare = ConcatPrepare;
+  concat->base_.Resize = ConcatResize;
+  concat->base_.Release = ConcatRelease;
+  concat->base_.Compute = ConcatCompute;
+  return (KernelBase *)concat;
+}
+
+REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeBool, CreateConcat)
+REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeInt32, CreateConcat)
+REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat32, CreateConcat)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.h
new file mode 100644
index 00000000..cdc201f1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/concat.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONCAT_H_
+#define NNACL_KERNEL_CONCAT_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct ConcatBlockBoundaryInfo {
+  int begin_input_;      // input-index of upper boundary
+  int end_input_;        // input-index of lower boundary.
+  int64_t begin_point_;  // offset of begin-input.
+  int64_t end_point_;    // required size of end-input.
+} ConcatBlockBoundaryInfo;
+
+typedef struct ConcatStruct {
+  KernelBase base_;
+  int64_t outer_size_;
+  uint8_t *output_;
+  TypeIdC data_type_;
+
+  bool *is_with_data_;   /* size = in_tensor_size */
+  uint8_t **inputs_ptr_; /* size = in_tensor_size */
+  int64_t *inner_sizes_; // byte-inner-size (including axis) of each input and the last one is output's.
+
+  ConcatBlockBoundaryInfo block_boundary_infos_[MAX_THREAD_NUM]; /* dynamic block size */
+  int64_t block_splits_[MAX_THREAD_NUM];                         /* dynamic block size */
+  size_t block_size_;                                            /* dynamic block size = actual thread number */
+} ConcatStruct;
+
+KernelBase *CreateConcat(OpParameter *param, int data_type);
+int DoConcat(ConcatStruct *concat, int task_id);
+int ConcatPrepare(KernelBase *self);
+int ConcatRelease(KernelBase *self);
+int ConcatResize(KernelBase *self);
+
+#endif  // NNACL_KERNEL_CONCAT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.c
new file mode 100644
index 00000000..f52b7d28
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.c
@@ -0,0 +1,365 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/convolution_1x1.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/base/conv1x1_base.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+
+int Conv1x1Run(void *cdata, int task_id, float l, float r) {
+  Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_1x1);
+  MatMulParameter *matmul = &conv_1x1->matmul_param_;
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->thread_stride_, NNACL_ERR);
+  int total_thread_stride_ = task_id * conv_1x1->thread_stride_;
+  int res_stride = matmul->col_ - total_thread_stride_;
+  int cur_oc = MSMIN(conv_1x1->thread_stride_, res_stride);
+  if (cur_oc <= 0) {
+    return NNACL_OK;
+  }
+
+  TensorC *out_tensor = conv_1x1->conv_.base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
+  float *bias = conv_1x1->conv_.bias_data_ == NULL
+                  ? NULL
+                  : (float *)conv_1x1->conv_.bias_data_ + conv_1x1->thread_stride_ * task_id;
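+  // This task covers cur_oc consecutive output channels, so the packed weight and the bias
+  // both start at the task's first output column.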
+  float *weight = (float *)conv_1x1->conv_.packed_weight_ + total_thread_stride_ * matmul->deep_;
+
+  if (out_tensor->format_ == Format_NC4HW4) {
+    MatMulOpt(conv_1x1->pack_input_, weight, conv_1x1->output_ptr_ + total_thread_stride_ * matmul->row_, bias,
+              matmul->act_type_, matmul->deep_, matmul->row_, cur_oc, matmul->row_, OutType_NC4HW4);
+  } else {
+    MatMulOpt(conv_1x1->pack_input_, weight, conv_1x1->output_ptr_ + total_thread_stride_, bias, matmul->act_type_,
+              matmul->deep_, matmul->row_, cur_oc, matmul->col_, OutType_Nhwc);
+  }
+  return NNACL_OK;
+}
+
+void Conv1x1PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col) {
+#ifdef ENABLE_AVX
+  RowMajor2Col6Major(src_ptr, dst_ptr, row, col);
+#elif defined(ENABLE_SSE)
+  RowMajor2Col4Major(src_ptr, dst_ptr, row, col);
+#else
+  RowMajor2Col12Major(src_ptr, dst_ptr, row, col);
+#endif
+}
+
+int Conv1x1RunHw(void *cdata, int task_id, float l, float r) {
+  Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_1x1);
+  MatMulParameter *matmul = &conv_1x1->matmul_param_;
+  TensorC *output_tensor = conv_1x1->conv_.base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->thread_stride_, NNACL_ERR);
+  int total_thread_stride_ = task_id * conv_1x1->thread_stride_;
+  int res_stride = matmul->row_ - total_thread_stride_;
+  int cur_hw_ = MSMIN(conv_1x1->thread_stride_, res_stride);
+  if (cur_hw_ <= 0) {
+    return NNACL_OK;
+  }
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_, matmul->deep_, NNACL_ERR);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->row_tile_, NNACL_ERR);
+  int total_row_tile_ = task_id * conv_1x1->row_tile_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_row_tile_, matmul->deep_, NNACL_ERR);
+  float *thread_input_ptr = conv_1x1->input_ptr_ + total_thread_stride_ * matmul->deep_;
+  float *thread_pack_input = conv_1x1->pack_input_ + total_row_tile_ * matmul->deep_;
+  float *thread_output_ptr = NULL;
+  if (output_tensor->format_ != Format_NC4HW4) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_, matmul->col_, NNACL_ERR);
+    thread_output_ptr = conv_1x1->output_ptr_ + total_thread_stride_ * matmul->col_;
+  } else {
+    int col_min = MSMIN(matmul->col_, C4NUM);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_, col_min, NNACL_ERR);
+    thread_output_ptr = conv_1x1->output_ptr_ + total_thread_stride_ * col_min;
+  }
+  float *cur_input = thread_input_ptr;
+  float *cur_output = thread_output_ptr;
+  float *bias = (float *)conv_1x1->conv_.bias_data_;
+  for (int i = 0; i < cur_hw_; i += conv_1x1->row_tile_) {
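+    // Pack up to row_tile_ output pixels into the tiled layout expected by MatMulOpt,
+    // multiply them against the packed weight, then advance the input/output cursors.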
+    int cur_rows = (cur_hw_ - i >= conv_1x1->row_tile_) ? conv_1x1->row_tile_ : (cur_hw_ - i);
+    Conv1x1PackMatmulInput(cur_input, thread_pack_input, cur_rows, matmul->deep_);
+    if (output_tensor->format_ == Format_NC4HW4) {
+      MatMulOpt(thread_pack_input, (float *)conv_1x1->conv_.packed_weight_, cur_output, bias, matmul->act_type_,
+                matmul->deep_, cur_rows, matmul->col_, matmul->row_, OutType_NC4HW4);
+      cur_output += conv_1x1->row_tile_ * MSMIN(matmul->col_, C4NUM);
+    } else {
+      MatMulOpt(thread_pack_input, (float *)conv_1x1->conv_.packed_weight_, cur_output, bias, matmul->act_type_,
+                matmul->deep_, cur_rows, matmul->col_, matmul->col_, OutType_Nhwc);
+      cur_output += conv_1x1->row_tile_ * matmul->col_;
+    }
+    cur_input += conv_1x1->row_tile_ * matmul->deep_;
+  }
+
+  return NNACL_OK;
+}
+
+void Conv1x1PackWeight(ConvolutionBaseStruct *conv) {
+  TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_VOID(filter_tensor);
+  ConvComputeParam *compute = &conv->compute_;
+  NNACL_CHECK_NULL_RETURN_VOID(compute);
+
+  if (compute->in_c_ <= 0 || compute->out_c_ <= 0) {
+    return;
+  }
+
+  void *origin_weight = conv->base_.train_session_ ? filter_tensor->data_ : conv->origin_weight_;
+  NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
+
+#ifdef ENABLE_AVX
+  RowMajor2Col16Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_);
+#elif defined(ENABLE_ARM32)
+  RowMajor2Col4Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_);
+#else
+  RowMajor2Col8Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_);
+#endif
+}
+
+int Conv1x1MallocWeightBiasData(ConvolutionBaseStruct *conv) {
+  Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)conv;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_1x1);
+
+  int size = conv->compute_.in_c_ * UP_ROUND(conv->compute_.out_c_, conv_1x1->col_tile_) * sizeof(float);
+  if (!conv->base_.train_session_) {
+    conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_);
+  }
+
+  if (conv->base_.in_size_ == THREE_TENSOR) {
+    size = UP_ROUND(conv->compute_.out_c_, conv_1x1->col_tile_) * sizeof(float);
+    conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_);
+    memset(conv->bias_data_, 0, size);
+  }
+  return NNACL_OK;
+}
+
+void Conv1x1FreeTmpBuffer(Convolution1x1Struct *conv_1x1) {
+  if (conv_1x1->pre_trans_input_ && conv_1x1->input_ptr_ != NULL) {
+    conv_1x1->conv_.base_.env_->Free(conv_1x1->conv_.base_.env_->allocator_, conv_1x1->input_ptr_);
+    conv_1x1->input_ptr_ = NULL;
+  }
+}
+
+int InitConv1x1MatmulParam(Convolution1x1Struct *conv_1x1) {
+  ConvParameter *conv_param = (ConvParameter *)conv_1x1->conv_.base_.param_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->output_h_, conv_param->output_w_, NNACL_ERR);
+  conv_1x1->matmul_param_.row_ = conv_param->output_h_ * conv_param->output_w_;
+  conv_1x1->matmul_param_.col_ = conv_param->output_channel_;
+  conv_1x1->matmul_param_.deep_ = conv_param->input_channel_;
+  conv_1x1->matmul_param_.row_align_ = UP_ROUND(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_);
+  conv_1x1->matmul_param_.col_align_ = UP_ROUND(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_);
+  conv_1x1->matmul_param_.act_type_ = conv_param->act_type_;
+  return NNACL_OK;
+}
+
+int InitConv1x1Param(Convolution1x1Struct *conv_1x1) {
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->row_tile_, conv_1x1->conv_.base_.thread_nr_, NNACL_ERR);
+  if
((conv_1x1->matmul_param_.row_ > (conv_1x1->row_tile_ * conv_1x1->conv_.base_.thread_nr_)) && + (conv_1x1->matmul_param_.row_ > conv_1x1->matmul_param_.col_)) { + conv_1x1->multi_thread_by_hw_ = true; + conv_1x1->conv_.base_.thread_nr_ = + MSMIN(conv_1x1->conv_.base_.thread_nr_, UP_DIV(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_)); + if (conv_1x1->conv_.base_.thread_nr_ <= 0) { + return NNACL_ERR; + } + conv_1x1->thread_stride_ = + UP_DIV(UP_DIV(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_), conv_1x1->conv_.base_.thread_nr_) * + conv_1x1->row_tile_; + } else { + conv_1x1->multi_thread_by_hw_ = false; + conv_1x1->conv_.base_.thread_nr_ = + MSMIN(conv_1x1->conv_.base_.thread_nr_, UP_DIV(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_)); + if (conv_1x1->conv_.base_.thread_nr_ <= 0) { + return NNACL_ERR; + } + conv_1x1->thread_stride_ = + UP_DIV(UP_DIV(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_), conv_1x1->conv_.base_.thread_nr_) * + conv_1x1->col_tile_; + } + + ConvParameter *conv_param = (ConvParameter *)conv_1x1->conv_.base_.param_; + conv_1x1->pre_trans_input_ = + (conv_param->pad_u_ != 0 || conv_param->pad_l_ != 0 || conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1); + if (conv_1x1->pre_trans_input_) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + conv_1x1->input_ptr_ = (float *)(conv_1x1->conv_.base_.env_->Alloc( + conv_1x1->conv_.base_.env_->allocator_, + conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.deep_ * sizeof(float))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_1x1->input_ptr_); + memset(conv_1x1->input_ptr_, 0, conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.deep_ * sizeof(float)); + } + + return NNACL_OK; +} + +int Convolution1x1Resize(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + + Conv1x1FreeTmpBuffer(conv_1x1); + int error_code = ConvBasePrepare(&conv_1x1->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = InitConv1x1MatmulParam(conv_1x1); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = InitConv1x1Param(conv_1x1); + if (error_code != NNACL_OK) { + return error_code; + } + + return NNACL_OK; +} + +int Convolution1x1Prepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + +#ifdef ENABLE_AVX + conv_1x1->row_tile_ = C6NUM; + conv_1x1->col_tile_ = C16NUM; +#elif defined(ENABLE_SSE) + conv_1x1->row_tile_ = C4NUM; + conv_1x1->col_tile_ = C8NUM; +#elif defined(ENABLE_ARM32) + conv_1x1->row_tile_ = C12NUM; + conv_1x1->col_tile_ = C4NUM; +#else + conv_1x1->row_tile_ = C12NUM; + conv_1x1->col_tile_ = C8NUM; +#endif + + if (self->train_session_) { + int output_tile_size = UP_ROUND(conv_1x1->conv_.compute_.out_c_, conv_1x1->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->conv_.compute_.in_c_, output_tile_size, NNACL_ERR); + size_t size = conv_1x1->conv_.compute_.in_c_ * output_tile_size * sizeof(float); + conv_1x1->conv_.base_.work_size_ = size; + } + + int error_code = ConvBaseInitConvWeightBias(&conv_1x1->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + return NNACL_OK; +} + +int Convolution1x1Release(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + 
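+  // Frees only what this kernel owns: the padded-input copy (when one was created) and
+  // the packed weight/bias buffers managed by the base.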
NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + Conv1x1FreeTmpBuffer(conv_1x1); + ConvBaseRelease(&conv_1x1->conv_); + return NNACL_OK; +} + +int Convolution1x1Compute(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + float *src_in = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + float *src_out = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + int pack_input_size = 0; + if (conv_1x1->multi_thread_by_hw_) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->conv_.base_.thread_nr_, conv_1x1->row_tile_, NNACL_ERR); + int total_row_tile_ = conv_1x1->conv_.base_.thread_nr_ * conv_1x1->row_tile_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_row_tile_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + pack_input_size = total_row_tile_ * conv_1x1->matmul_param_.deep_; + } else { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_align_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + pack_input_size = conv_1x1->matmul_param_.row_align_ * conv_1x1->matmul_param_.deep_; + } + conv_1x1->pack_input_ = + (float *)conv_1x1->conv_.base_.env_->Alloc(conv_1x1->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_1x1->pack_input_); + + int ret = ConvBaseRepackWeight(&conv_1x1->conv_); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_, conv_1x1->matmul_param_.col_, NNACL_ERR); + int matmul_size = conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.col_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_batch_ - 1, matmul_size, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_h_, conv_param->input_w_, NNACL_ERR); + int conv_input_hw = conv_param->input_h_ * conv_param->input_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_hw, conv_param->input_channel_, NNACL_ERR); + int conv_input_bhw = conv_input_hw * conv_param->input_channel_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_batch_ - 1, conv_input_bhw, NNACL_ERR); + for (int batch_index = 0; batch_index < conv_param->input_batch_; batch_index++) { + conv_1x1->output_ptr_ = src_out + batch_index * matmul_size; + float *tmp_in = src_in + batch_index * conv_input_bhw; + if (conv_1x1->pre_trans_input_) { + Conv1x1InputPack(tmp_in, conv_1x1->input_ptr_, conv_param, sizeof(float)); + } else { + conv_1x1->input_ptr_ = tmp_in; + } + if (conv_1x1->multi_thread_by_hw_) { + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, Conv1x1RunHw, self, self->thread_nr_); + } else { + Conv1x1PackMatmulInput(conv_1x1->input_ptr_, conv_1x1->pack_input_, conv_1x1->matmul_param_.row_, + conv_1x1->matmul_param_.deep_); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, Conv1x1Run, self, self->thread_nr_); + } + if (ret != NNACL_OK) { + break; + } + } + + if (conv_1x1->pack_input_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_1x1->pack_input_); + conv_1x1->pack_input_ = NULL; + } + return ret; +} + +ConvolutionBaseStruct *CreateConvolution1x1(ConvParameter *conv_param) { + Convolution1x1Struct *conv1x1 = (Convolution1x1Struct *)malloc(sizeof(Convolution1x1Struct)); + 
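+  // Only virtual methods and packing hooks are wired here; tile sizes are chosen in
+  // Prepare and working buffers are allocated during Resize/Compute.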
NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv1x1);
+  memset(conv1x1, 0, sizeof(Convolution1x1Struct));
+
+  conv1x1->conv_.is_sharing_pack_ = false;
+  conv1x1->conv_.malloc_weight_bias_ = Conv1x1MallocWeightBiasData;
+  conv1x1->conv_.pack_weight_ = Conv1x1PackWeight;
+
+  conv1x1->conv_.base_.Resize = Convolution1x1Resize;
+  conv1x1->conv_.base_.Prepare = Convolution1x1Prepare;
+  conv1x1->conv_.base_.Release = Convolution1x1Release;
+  conv1x1->conv_.base_.Compute = Convolution1x1Compute;
+
+  return (ConvolutionBaseStruct *)conv1x1;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.h
new file mode 100644
index 00000000..bd26ed48
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_1x1.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_1X1_H_
+#define NNACL_KERNEL_CONVOLUTION_1X1_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
+
+typedef struct Convolution1x1Struct {
+  ConvolutionBaseStruct conv_;
+  MatMulParameter matmul_param_;
+  int row_tile_;
+  int col_tile_;
+  bool pre_trans_input_;
+  float *input_ptr_;
+  float *output_ptr_;
+  float *pack_input_;
+  bool multi_thread_by_hw_;
+  int thread_stride_;
+} Convolution1x1Struct;
+
+ConvolutionBaseStruct *CreateConvolution1x1(ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_1X1_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.c
new file mode 100644
index 00000000..8b49cbb3
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.c
@@ -0,0 +1,209 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/tensor_c_utils.h" + +int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param) { + compute->stride_h_ = conv_param->stride_h_; + compute->stride_w_ = conv_param->stride_w_; + compute->dilation_h_ = conv_param->dilation_h_; + compute->dilation_w_ = conv_param->dilation_w_; + compute->pad_u_ = conv_param->pad_u_; + compute->pad_d_ = conv_param->pad_d_; + compute->pad_l_ = conv_param->pad_l_; + compute->pad_r_ = conv_param->pad_r_; + + compute->in_c_ = conv_param->input_channel_; + compute->out_c_ = conv_param->output_channel_; + + compute->kernel_h_ = conv_param->kernel_h_; + compute->kernel_w_ = conv_param->kernel_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_h_, compute->kernel_w_, NNACL_ERR); + compute->kernel_hw_ = compute->kernel_h_ * compute->kernel_w_; + + return NNACL_OK; +} + +int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv) { + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + TensorC *input = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = conv->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + conv_param->input_batch_ = NNACLGetBatch(input); + conv_param->input_h_ = NNACLGetHeight(input); + conv_param->input_w_ = NNACLGetWidth(input); + conv_param->input_channel_ = NNACLGetChannel(input); + conv_param->output_batch_ = NNACLGetBatch(output); + conv_param->output_h_ = NNACLGetHeight(output); + conv_param->output_w_ = NNACLGetWidth(output); + conv_param->output_channel_ = NNACLGetChannel(output); + + ConvComputeParam *compute = &conv->compute_; + compute->in_n_ = NNACLGetBatch(input); + compute->in_h_ = NNACLGetHeight(input); + compute->in_w_ = NNACLGetWidth(input); + compute->in_c_ = NNACLGetChannel(input); + NNACL_CHECK_FALSE(compute->in_c_ != conv_param->input_channel_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_h_, compute->in_w_, NNACL_ERR); + compute->in_hw_ = compute->in_h_ * compute->in_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_, compute->in_n_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_ * compute->in_n_, compute->in_c_, NNACL_ERR); + + compute->out_n_ = NNACLGetBatch(output); + compute->out_h_ = NNACLGetHeight(output); + compute->out_w_ = NNACLGetWidth(output); + compute->out_c_ = NNACLGetChannel(output); + NNACL_CHECK_FALSE(compute->out_c_ != conv_param->output_channel_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_h_, compute->out_w_, NNACL_ERR); + compute->out_hw_ = compute->out_h_ * compute->out_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_, compute->out_n_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_ * compute->out_n_, compute->out_c_, NNACL_ERR); + + return ConvBaseUpdateParamInfo(compute, conv_param); +} + +void ConvBaseRelease(ConvolutionBaseStruct *conv) { + if (!conv->base_.train_session_) { + if (!conv->is_sharing_pack_) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv->packed_weight_); + } else { + conv->free_sharing_weight_(conv->shaing_manager_, conv->packed_weight_); + } + conv->packed_weight_ = NULL; + } + + if (conv->bias_data_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv->bias_data_); + conv->bias_data_ = NULL; + } +} + +int ConvBasePrepare(ConvolutionBaseStruct *conv) { + NNACL_CHECK_FALSE(conv->base_.in_size_ < TWO_TENSOR, 
NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(conv->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
+
+  conv->out_format_ = conv->base_.out_[OUTPUT_INDEX]->format_;
+  return ConvBaseUpdateComputeInfo(conv);
+}
+
+void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv) {
+  NNACL_CHECK_NULL_RETURN_VOID(conv);
+
+  if (conv->base_.in_[SECOND_INPUT]->data_ != NULL) {
+    conv->origin_weight_ = conv->base_.in_[SECOND_INPUT]->data_;
+  }
+
+  if (conv->base_.in_size_ == THREE_TENSOR && conv->base_.in_[THIRD_INPUT]->data_ != NULL) {
+    conv->origin_bias_ = conv->base_.in_[THIRD_INPUT]->data_;
+  }
+}
+
+int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv) {
+  if (conv->base_.train_session_) {
+    ConvBaseUpdateOriginWeightAndBias(conv);
+  }
+
+  /* the weight's shape must already be inferred before the buffers can be sized */
+  if (!CheckInferShapeDone(&conv->base_.in_[SECOND_INPUT], ONE_TENSOR, NULL, 0)) {
+    return NNACL_OK;
+  }
+
+  int ret = conv->malloc_weight_bias_(conv);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  if ((conv->base_.in_size_ == THREE_TENSOR) && (conv->origin_bias_ != NULL)) {
+    memcpy(conv->bias_data_, conv->origin_bias_, NNACLGetSize(conv->base_.in_[THIRD_INPUT]));
+  }
+
+  if (!conv->base_.train_session_) {
+    if (conv->weight_is_packed_) {
+      return NNACL_OK;
+    }
+    if (conv->origin_weight_ != NULL) {
+      conv->pack_weight_(conv);
+    } else {
+      conv->is_repack_ = true;
+    }
+  }
+  return NNACL_OK;
+}
+
+int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv) {
+  // ===============check in channel================= //
+  TensorC *input_tensor = conv->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  int resize_in_channel = NNACLGetChannel(input_tensor);
+  TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(filter_tensor);
+  int filter_in_channel = NNACLGetChannel(filter_tensor);
+  if (filter_in_channel != resize_in_channel) {
+    return NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH;
+  }
+  return NNACL_OK;
+}
+
+void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size) {
+  TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT];
+  bool weight_not_const = weight_tensor->category_ != ConstTensor && weight_tensor->category_ != ConstScalar;
+  bool is_group_conv = ((ConvParameter *)conv->base_.param_)->group_ > 1;
+  bool no_weight_sharing = conv->get_sharing_weight_ == NULL;
+
+  void *data = NULL;
+  if (no_weight_sharing || weight_not_const || is_group_conv) {
+    if (data_size <= 0) {
+      return NULL;
+    }
+    data = conv->base_.env_->Alloc(conv->base_.env_->allocator_, data_size);
+    conv->weight_is_packed_ = false;
+    conv->is_sharing_pack_ = false;
+  } else {
+    data = conv->get_sharing_weight_(conv->shaing_manager_, weight_tensor->data_, data_size, &conv->weight_is_packed_);
+  }
+  return data;
+}
+
+int ConvBaseRepackWeight(ConvolutionBaseStruct *conv) {
+  NNACL_CHECK_NULL_RETURN_ERR(conv);
+
+  conv->origin_weight_ = conv->origin_weight_ != NULL ? conv->origin_weight_ : conv->base_.in_[SECOND_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv->origin_weight_);
+
+  if (conv->packed_weight_ == NULL) {
+    int ret = ConvBaseInitConvWeightBias(conv);
+    if (ret != NNACL_OK) {
+      return ret;
+    }
+  }
+
+  if (conv->is_repack_ || conv->base_.train_session_) {
+    if (conv->base_.train_session_) {
+      conv->packed_weight_ = (float *)conv->base_.workspace_;
+      memset(conv->packed_weight_, 0, conv->base_.work_size_);
+    } else {
+      conv->is_repack_ = false;
+    }
+    conv->pack_weight_(conv);
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.h
new file mode 100644
index 00000000..c932b602
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_base.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_BASE_H_
+#define NNACL_KERNEL_CONVOLUTION_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+#define ConvMinBlock 1
+
+typedef struct ConvolutionBaseStruct {
+  KernelBase base_;
+  ConvComputeParam compute_;
+  bool weight_is_packed_;
+  bool is_repack_;
+  bool infershape_done_;
+  bool use_batch_cut_flag_;
+  FormatC out_format_;
+
+  void *packed_weight_;
+  void *bias_data_;
+  void *origin_weight_;  // do not Free
+  void *origin_bias_;    // do not Free
+
+  void (*init_global_variable_)(struct ConvolutionBaseStruct *conv_im2col);
+  int (*malloc_weight_bias_)(struct ConvolutionBaseStruct *conv_base);
+  void (*pack_weight_)(struct ConvolutionBaseStruct *conv_base);
+  int (*run_impl_)(struct ConvolutionBaseStruct *conv, int task_id);
+
+  bool is_sharing_pack_;
+  void *shaing_manager_;
+  void (*free_sharing_weight_)(void *manager, void *tensor_data);
+  void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed);
+} ConvolutionBaseStruct;
+
+int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param);
+int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv);
+void ConvBaseRelease(ConvolutionBaseStruct *conv);
+int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv);
+int ConvBasePrepare(ConvolutionBaseStruct *conv);
+int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv);
+int ConvBaseRepackWeight(ConvolutionBaseStruct *conv);
+void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv);
+void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.c
new file mode 100644
index 00000000..7ee4aa2c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.c
@@ -0,0 +1,365 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/convolution_delegate.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/base/conv_common_base.h"
+#include "nnacl_c/kernel/group_convolution.h"
+#include "nnacl_c/kernel/convolution_depthwise.h"
+#include "nnacl_c/kernel/convolution_1x1.h"
+#include "nnacl_c/kernel/convolution_im2col.h"
+#include "nnacl_c/kernel/convolution_winograd.h"
+#include "nnacl_c/fp32/conv_winograd_fp32.h"
+#include "nnacl_c/kernel/convolution_depthwise_sw.h"
+#ifdef ENABLE_AVX
+#include "nnacl_c/kernel/convolution_sw_1x1.h"
+#include "nnacl_c/kernel/convolution_sw_avx.h"
+#include "nnacl_c/kernel/convolution_depthwise_sw_avx.h"
+#endif
+#ifdef ENABLE_ARM64
+#include "nnacl_c/kernel/convolution_depthwise_indirect.h"
+#include "nnacl_c/kernel/convolution_sw_arm64.h"
+#include "nnacl_c/fp32/conv_sw_arm64_fp32.h"
+#endif
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+#include "nnacl_c/kernel/convolution_depthwise_3x3.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
+#endif
+
+#define MaxDwConvSWSize 32
+
+float *ConvolutionDelegateCopyData(const TensorC *tensor) {
+  NNACL_CHECK_NULL_RETURN_NULL(tensor);
+  NNACL_CHECK_NULL_RETURN_NULL(tensor->data_);
+
+  float *data = (float *)malloc(NNACLGetSize(tensor));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(data);
+
+  (void)memcpy(data, tensor->data_, NNACLGetSize(tensor));
+  return data;
+}
+
+int ConvolutionDelegateGetWeightData(ConvolutionDelegateStruct *convolution_delegate) {
+  if (convolution_delegate->conv_.base_.in_[SECOND_INPUT]->data_ == NULL) {
+    return NNACL_OK;
+  }
+  if (convolution_delegate->conv_.infershape_done_) {
+    convolution_delegate->origin_weight_ = convolution_delegate->conv_.base_.in_[SECOND_INPUT]->data_;
+    NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_weight_);
+    convolution_delegate->need_free_weight_ = false;
+    return NNACL_OK;
+  }
+  convolution_delegate->origin_weight_ =
+    ConvolutionDelegateCopyData(convolution_delegate->conv_.base_.in_[SECOND_INPUT]);
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_weight_);
+  convolution_delegate->need_free_weight_ = true;
+  return NNACL_OK;
+}
+
+int ConvolutionDelegateGetBiasData(ConvolutionDelegateStruct *convolution_delegate) {
+  if (convolution_delegate->conv_.base_.in_size_ != THREE_TENSOR) {
+    convolution_delegate->origin_bias_ = NULL;
+    convolution_delegate->need_free_bias_ = false;
+    return NNACL_OK;
+  }
+
+  if (convolution_delegate->conv_.infershape_done_) {
+    NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->conv_.base_.in_[THIRD_INPUT]);
+    convolution_delegate->origin_bias_ = convolution_delegate->conv_.base_.in_[THIRD_INPUT]->data_;
+    NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_bias_);
+    convolution_delegate->need_free_bias_ = false;
+    return
NNACL_OK;
+  }
+
+  convolution_delegate->origin_bias_ = ConvolutionDelegateCopyData(convolution_delegate->conv_.base_.in_[THIRD_INPUT]);
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_bias_);
+  convolution_delegate->need_free_bias_ = true;
+  return NNACL_OK;
+}
+
+int ConvolutionDelegateGetWeightAndBias(ConvolutionDelegateStruct *convolution_delegate) {
+  int ret = ConvolutionDelegateGetWeightData(convolution_delegate);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  return ConvolutionDelegateGetBiasData(convolution_delegate);
+}
+
+ConvolutionBaseStruct *ConvolutionDelegateConvNC4KernelSelect(ConvolutionDelegateStruct *convolution_delegate) {
+  /* runtime nc4hw4 pass:
+   * arm64: conv1x1 and conv_im2col support nc4
+   * AVX:   conv_im2col supports nc4
+   */
+  ConvParameter *conv_param = (ConvParameter *)convolution_delegate->conv_.base_.param_;
+  NNACL_CHECK_NULL_RETURN_NULL(conv_param);
+
+#ifdef ENABLE_ARM64
+  if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
+    ConvolutionBaseStruct *conv1x1 = CreateConvolution1x1(conv_param);
+    return conv1x1;
+  }
+#endif
+
+#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
+  ConvolutionBaseStruct *conv_im2col = CreateConvolutionIm2Col(&convolution_delegate->conv_.base_, conv_param);
+  return conv_im2col;
+#endif
+
+  return NULL;
+}
+
+ConvolutionBaseStruct *ConvolutionDelegateConvNHWCKernelSelect(ConvolutionDelegateStruct *convolution_delegate) {
+  ConvParameter *conv_param = (ConvParameter *)convolution_delegate->conv_.base_.param_;
+  NNACL_CHECK_NULL_RETURN_NULL(conv_param);
+
+  ConvolutionBaseStruct *conv = NULL;
+
+  int out_unit;
+  if (CheckIfUseWinograd(&out_unit, conv_param)) {
+    conv = CreateConvolutionWinograd(conv_param, out_unit);
+  }
+
+#ifdef ENABLE_AVX
+  if (conv == NULL && CheckAvxUseSW1x1Conv(conv_param)) {
+    conv = CreateConvolutionSW1x1(conv_param, convolution_delegate->input_const_, convolution_delegate->weight_const_);
+  }
+
+  if (conv == NULL && CheckAvxUseSWConv(conv_param, convolution_delegate->conv_.base_.thread_nr_)) {
+    conv = CreateConvolutionSWAVX(conv_param);
+  }
+#endif
+
+#ifdef ENABLE_ARM64
+  if (conv == NULL && CheckArm64UseSWConv(conv_param)) {
+    conv = CreateConvolutionSWARM64(conv_param);
+  }
+#endif
+
+  if (conv == NULL) {
+    if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
+      conv = CreateConvolution1x1(conv_param);
+    } else {
+      conv = CreateConvolutionIm2Col(&convolution_delegate->conv_.base_, conv_param);
+    }
+  }
+  return conv;
+}
+
+ConvolutionBaseStruct *ConvolutionDelegateConvolutionSelect(ConvolutionDelegateStruct *convolution_delegate) {
+  ConvolutionBaseStruct *conv;
+  if (convolution_delegate->conv_.base_.out_[OUTPUT_INDEX]->format_ == Format_NC4HW4) {
+    conv = ConvolutionDelegateConvNC4KernelSelect(convolution_delegate);
+  } else {
+    conv = ConvolutionDelegateConvNHWCKernelSelect(convolution_delegate);
+  }
+  if (conv == NULL) {
+    return NULL;
+  }
+
+  conv->base_.InferShape = convolution_delegate->conv_.base_.InferShape;
+  conv->base_.UpdateThread = convolution_delegate->conv_.base_.UpdateThread;
+  conv->base_.env_ = convolution_delegate->conv_.base_.env_;
+  conv->base_.param_ = convolution_delegate->conv_.base_.param_;
+  conv->base_.thread_nr_ = convolution_delegate->conv_.base_.thread_nr_;
+  conv->base_.train_session_ = convolution_delegate->conv_.base_.train_session_;
+  conv->base_.in_ = convolution_delegate->conv_.base_.in_;
+  conv->base_.in_size_ = convolution_delegate->conv_.base_.in_size_;
+  conv->base_.out_ = convolution_delegate->conv_.base_.out_;
conv->base_.out_size_ = convolution_delegate->conv_.base_.out_size_; + + conv->infershape_done_ = convolution_delegate->conv_.infershape_done_; + conv->shaing_manager_ = convolution_delegate->conv_.shaing_manager_; + conv->get_sharing_weight_ = convolution_delegate->conv_.get_sharing_weight_; + conv->free_sharing_weight_ = convolution_delegate->conv_.free_sharing_weight_; + conv->is_sharing_pack_ = convolution_delegate->conv_.is_sharing_pack_; + + conv->origin_weight_ = convolution_delegate->origin_weight_; + conv->origin_bias_ = convolution_delegate->origin_bias_; + return conv; +} + +void ConvolutionDelegateFreeCopiedData(ConvolutionDelegateStruct *convolution_delegate) { + if (convolution_delegate->origin_weight_ != NULL && convolution_delegate->need_free_weight_) { + free(convolution_delegate->origin_weight_); + } + convolution_delegate->origin_weight_ = NULL; + convolution_delegate->conv_.origin_weight_ = NULL; + convolution_delegate->need_free_weight_ = false; + + if (convolution_delegate->origin_bias_ != NULL && convolution_delegate->need_free_bias_) { + free(convolution_delegate->origin_bias_); + } + convolution_delegate->origin_bias_ = NULL; + convolution_delegate->conv_.origin_bias_ = NULL; + convolution_delegate->need_free_bias_ = false; +} + +int ConvolutionDelegateResize(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + + if (convolution_delegate->convolution_ == NULL) { + convolution_delegate->convolution_ = ConvolutionDelegateConvolutionSelect(convolution_delegate); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->convolution_); + (void)ConvBaseUpdateComputeInfo(convolution_delegate->convolution_); + int ret = convolution_delegate->convolution_->base_.Prepare(&convolution_delegate->convolution_->base_); + if (ret != NNACL_OK) { + return ret; + } + } + + (void)ConvBaseUpdateComputeInfo(convolution_delegate->convolution_); + int ret = convolution_delegate->convolution_->base_.Resize(&convolution_delegate->convolution_->base_); + if (ret != NNACL_OK) { + return ret; + } + + ConvolutionDelegateFreeCopiedData(convolution_delegate); + return NNACL_OK; +} + +int ConvolutionDelegatePrepare(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]); + + NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32 && + self->in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat16, + NNACL_CONVOLUTION_WEIGHT_DATATYPE_INVALID); + NNACL_CHECK_FALSE(self->in_size_ == THREE_TENSOR && self->in_[THIRD_INPUT] != NULL && + self->in_[THIRD_INPUT]->data_type_ != kNumberTypeFloat32, + NNACL_CONVOLUTION_BIAS_DATATYPE_INVALID); + + convolution_delegate->input_const_ = NNACLIsConst(self->in_[FIRST_INPUT]) && !self->train_session_; + convolution_delegate->weight_const_ = NNACLIsConst(self->in_[SECOND_INPUT]) && !self->train_session_; + + return ConvolutionDelegateGetWeightAndBias(convolution_delegate); +} + +int ConvolutionDelegateRelease(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + if (convolution_delegate->convolution_ != 
NULL) {
+    (void)convolution_delegate->convolution_->base_.Release(&convolution_delegate->convolution_->base_);
+    free(convolution_delegate->convolution_);
+    convolution_delegate->convolution_ = NULL;
+  }
+  if (convolution_delegate->need_free_weight_ && convolution_delegate->origin_weight_ != NULL) {
+    free(convolution_delegate->origin_weight_);
+    convolution_delegate->origin_weight_ = NULL;
+  }
+  if (convolution_delegate->need_free_bias_ && convolution_delegate->origin_bias_ != NULL) {
+    free(convolution_delegate->origin_bias_);
+    convolution_delegate->origin_bias_ = NULL;
+  }
+  return NNACL_OK;
+}
+
+int ConvolutionDelegateCompute(struct KernelBase *self) {
+  ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate);
+  NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->convolution_);
+
+  convolution_delegate->convolution_->base_.workspace_ = convolution_delegate->conv_.base_.workspace_;
+  return convolution_delegate->convolution_->base_.Compute(&convolution_delegate->convolution_->base_);
+}
+
+KernelBase *CreateConvolutionDelegate(ConvParameter *conv_param) {
+  ConvolutionDelegateStruct *delegate = (ConvolutionDelegateStruct *)malloc(sizeof(ConvolutionDelegateStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(delegate);
+  memset(delegate, 0, sizeof(ConvolutionDelegateStruct));
+  delegate->conv_.base_.Prepare = ConvolutionDelegatePrepare;
+  delegate->conv_.base_.Resize = ConvolutionDelegateResize;
+  delegate->conv_.base_.Release = ConvolutionDelegateRelease;
+  delegate->conv_.base_.Compute = ConvolutionDelegateCompute;
+  return (KernelBase *)delegate;
+}
+
+KernelBase *CreateConvolutionDepthwise(ConvParameter *conv_param) {
+  NNACL_CHECK_NULL_RETURN_NULL(conv_param);
+  KernelBase *kernel = NULL;
+
+  if (conv_param->dynamic_shape_) {
+    kernel = CreateConvDw(conv_param);
+    if (kernel != NULL) {
+      return kernel;
+    }
+  }
+
+#ifdef ENABLE_AVX
+  kernel = CreateConvDwSWAVX(conv_param);
+  if (kernel != NULL) {
+    return kernel;
+  }
+#endif
+
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+  if (CheckConvDw1DWinograd(conv_param, conv_param->thread_num_)) {
+    kernel = CreateConvDw3x3(conv_param);
+    if (kernel != NULL) {
+      return kernel;
+    }
+  }
+#endif
+
+#ifdef ENABLE_ARM64
+  if (CheckConvDwUseIndirectBuffer(conv_param)) {
+    kernel = CreateConvDwIndirect(conv_param);
+    if (kernel != NULL) {
+      return kernel;
+    }
+  }
+#endif
+
+  if (conv_param->input_channel_ < MaxDwConvSWSize) {
+    kernel = CreateConvDwSW(conv_param);
+    if (kernel != NULL) {
+      return kernel;
+    }
+  }
+
+  kernel = CreateConvDw(conv_param);
+  return kernel;
+}
+
+KernelBase *CreateConv2DFusion(OpParameter *param, int data_type) {
+  ConvParameter *conv_param = (ConvParameter *)param;
+  conv_param->thread_num_ = param->thread_num_;
+  KernelBase *kernel;
+  if (conv_param->group_ == 1) {
+    kernel = CreateConvolutionDelegate(conv_param);
+  } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) {
+    kernel = CreateConvolutionDepthwise(conv_param);
+  } else {
+    kernel = CreateGroupConvolution(conv_param, data_type);
+  }
+
+  if (kernel == NULL) {
+    return NULL;
+  }
+
+  ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)kernel;
+  (void)ConvBaseUpdateParamInfo(&conv->compute_, conv_param);
+
+  return kernel;
+}
+
+REG_KERNEL_CREATOR(PrimType_Conv2DFusion, kNumberTypeFloat32, CreateConv2DFusion)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.h
new file mode 100644
index 00000000..0f8ebd83
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_delegate.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef NNACL_KERNEL_CONVOLUTION_DELEGATE_H_
+#define NNACL_KERNEL_CONVOLUTION_DELEGATE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/conv_parameter.h"
+
+typedef struct ConvolutionDelegateStruct {
+  ConvolutionBaseStruct conv_;         /* used for Conv2dFusion basic operator */
+  ConvolutionBaseStruct *convolution_; /* real running conv */
+  float *origin_weight_;
+  float *origin_bias_;
+  bool need_free_weight_;
+  bool need_free_bias_;
+  bool input_const_;
+  bool weight_const_;
+} ConvolutionDelegateStruct;
+
+KernelBase *CreateConvolutionDelegate(ConvParameter *conv_param);
+KernelBase *CreateConv2DFusion(OpParameter *param, int data_type);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_DELEGATE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.c
new file mode 100644
index 00000000..fbcb1761
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.c
@@ -0,0 +1,229 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_depthwise.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#ifdef ENABLE_AVX512 +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#endif +#include "nnacl_c/fp32/conv_depthwise_avx_fp32.h" + +int ConvDwRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + +#ifdef ENABLE_AVX512 + if (X86_Avx512_Support()) { + return ConvDwAVX512(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); + } else { + return ConvDwAVX(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); + } +#endif + +#ifdef ENABLE_AVX + return ConvDwAVX(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); +#endif + + return ConvDw(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id); +} + +void ConvDwReleaseParam(ConvolutionDepthwiseStruct *conv_dw) { + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (conv_dw->dw_param_.num_pixels_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.num_pixels_); + conv_dw->dw_param_.num_pixels_ = NULL; + } + if (conv_dw->dw_param_.out_w_start_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.out_w_start_); + conv_dw->dw_param_.out_w_start_ = NULL; + } + if (conv_dw->dw_param_.out_w_end_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.out_w_end_); + conv_dw->dw_param_.out_w_end_ = NULL; + } +} + +void ConvDwPackWeight(ConvolutionBaseStruct *conv) { + void *origin_data = conv->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_VOID(origin_data); + PackWeightKHWToHWKFp32(origin_data, conv->packed_weight_, conv->compute_.kernel_hw_, conv->compute_.out_c_); +} + +int ConvDwMallocWeightBiasData(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(weight_tensor); + + int pack_weight_size = conv->compute_.kernel_hw_ * conv->compute_.out_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(pack_weight_size, sizeof(float), NNACL_ERR); + + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + NNACL_CHECK_MALLOC_SIZE(conv->compute_.out_c_ * sizeof(float)); + if (conv->bias_data_ == NULL) { + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, conv->compute_.out_c_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, conv->compute_.out_c_ * sizeof(float)); + return NNACL_OK; +} + +int ConvDwInitConvDwCalcInfo(ConvolutionDepthwiseStruct *conv_dw) { + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + ConvComputeParam *compute = &conv_dw->conv_.compute_; + 
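+  // For every kernel column kw, precompute the output-x range [out_w_start, out_w_end)
+  // whose input taps are all in-bounds, so the depthwise row kernel can skip padding
+  // checks; first_calc_kw_ marks the first column whose window spans the whole output row.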
NNACL_CHECK_NULL_RETURN_ERR(compute); + + ConvDwReleaseParam(conv_dw); + + conv_dw->dw_param_.num_pixels_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.num_pixels_); + + conv_dw->dw_param_.out_w_start_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_start_); + + conv_dw->dw_param_.out_w_end_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_end_); + + int *num_pixels = (int *)(conv_dw->dw_param_.num_pixels_); + int *out_w_start = (int *)(conv_dw->dw_param_.out_w_start_); + int *out_w_end = (int *)(conv_dw->dw_param_.out_w_end_); + conv_dw->dw_param_.first_calc_kw_ = -1; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->dilation_w_, (compute->kernel_w_ - 1), NNACL_ERR); + for (int kw = 0; kw < compute->kernel_w_; kw++) { + out_w_start[kw] = + NNACL_MAX(0, (compute->pad_l_ - compute->dilation_w_ * kw + compute->stride_w_ - 1) / compute->stride_w_); + + out_w_end[kw] = NNACL_MIN( + (compute->in_w_ + compute->pad_l_ - compute->dilation_w_ * kw + compute->stride_w_ - 1) / compute->stride_w_, + compute->out_w_); + + num_pixels[kw] = out_w_end[kw] - out_w_start[kw]; + if (conv_dw->dw_param_.first_calc_kw_ == -1 && out_w_start[kw] == 0 && num_pixels[kw] == compute->out_w_) { + conv_dw->dw_param_.first_calc_kw_ = kw; + } + } + return NNACL_OK; +} + +int ConvolutionDepthwisePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + TensorC *weight_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + NNACL_CHECK_TRUE_RET(weight_tensor->shape_size_ == DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int weight_size_hw = NNACLGetHeight(weight_tensor) * NNACLGetWidth(weight_tensor); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetBatch(weight_tensor), weight_size_hw, NNACL_ERR); + int pack_weight_size = NNACLGetBatch(weight_tensor) * weight_size_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(pack_weight_size, sizeof(float), NNACL_ERR); + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseCompute(KernelBase *self) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + conv_dw->input_ptr_ = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->input_ptr_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + conv_dw->output_ptr_ = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.num_pixels_); + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_start_); + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_end_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwRun, self, self->thread_nr_); +} + +int 
ConvolutionDepthwiseResize(KernelBase *self) {
+  ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int ret = ConvBasePrepare(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_);
+  NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
+
+  ret = ConvDwInitConvDwCalcInfo(conv_dw);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  return NNACL_OK;
+}
+
+int ConvolutionDepthwiseRelease(KernelBase *self) {
+  ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  ConvDwReleaseParam(conv_dw);
+
+  ConvBaseRelease(&conv_dw->conv_);
+  return NNACL_OK;
+}
+
+KernelBase *CreateConvDw(ConvParameter *conv) {
+  ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)malloc(sizeof(ConvolutionDepthwiseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw);
+  memset(conv_dw, 0, sizeof(ConvolutionDepthwiseStruct));
+
+  conv_dw->conv_.pack_weight_ = ConvDwPackWeight;
+  conv_dw->conv_.malloc_weight_bias_ = ConvDwMallocWeightBiasData;
+  conv_dw->conv_.base_.Prepare = ConvolutionDepthwisePrepare;
+  conv_dw->conv_.base_.Compute = ConvolutionDepthwiseCompute;
+  conv_dw->conv_.base_.Resize = ConvolutionDepthwiseResize;
+  conv_dw->conv_.base_.Release = ConvolutionDepthwiseRelease;
+  return (KernelBase *)conv_dw;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.h
new file mode 100644
index 00000000..96af3331
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_H_
+#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionDepthwiseStruct {
+  ConvolutionBaseStruct conv_;
+  ConvDwCalcParam dw_param_;
+  float *input_ptr_;
+  float *output_ptr_;
+} ConvolutionDepthwiseStruct;
+
+int ConvolutionDepthwiseRelease(KernelBase *self);
+KernelBase *CreateConvDw(ConvParameter *conv);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.c
new file mode 100644
index 00000000..199ac49b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.c
@@ -0,0 +1,154 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+#include "nnacl_c/kernel/convolution_depthwise_3x3.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+
+int ConvDw3x3Run(void *cdata, int task_id, float l, float r) {
+  ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int units = UP_DIV(conv_dw->conv_.compute_.out_w_, C2NUM);  // F(2, 3) contains 2 conv units
+  int c4 = UP_ROUND(conv_dw->conv_.compute_.in_c_, C4NUM);
+  int c12c4_units = C12NUM * c4 * units;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4_units, task_id, NNACL_ERR);
+  float *buffer = conv_dw->buffer_ + c12c4_units * task_id;
+  NNACL_CHECK_ZERO_RETURN_ERR(conv_dw->conv_.base_.thread_nr_);
+
+  int step_oh = UP_DIV(conv_dw->conv_.compute_.out_h_, conv_dw->conv_.base_.thread_nr_);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(step_oh, task_id, NNACL_ERR);
+  int start_oh = step_oh * task_id;
+  int end_oh = MSMIN(start_oh + step_oh, conv_dw->conv_.compute_.out_h_);
+
+  ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+  ConvDw3x3(conv_dw->output_ptr_, buffer, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_,
+            (float *)conv_dw->conv_.bias_data_, conv_param, start_oh, end_oh);
+  return NNACL_OK;
+}
+
+void ConvDw3x3PackWeight(ConvolutionBaseStruct *conv) {
+  void *origin_weight = (conv->base_.train_session_) ?
conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackWeightConvDw3x3Fp32((float *)origin_weight, (float *)conv->packed_weight_, conv->compute_.out_c_); +} + +int ConvDw3x3MallocWeightBiasData(ConvolutionBaseStruct *conv) { + int c4 = UP_ROUND(conv->compute_.out_c_, C4NUM); + if (!conv->base_.train_session_) { + if (conv->packed_weight_ == NULL) { + int pack_weight_size = c4 * C12NUM; + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + } + + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(c4 * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, c4 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, c4 * sizeof(float)); + return NNACL_OK; +} + +int ConvolutionDepthwise3x3Resize(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + int ret = ConvBasePrepare(conv); + if (ret != NNACL_OK) { + return ret; + } + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv->compute_.out_h_); + return NNACL_OK; +} + +int ConvolutionDepthwise3x3Prepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int c4 = UP_ROUND(conv_dw->conv_.compute_.out_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c4, C12NUM, NNACL_ERR); + int pack_weight_size = c4 * C12NUM; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwise3x3Compute(KernelBase *self) { + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int units = UP_DIV(conv_dw->conv_.compute_.out_w_, C2NUM); // F(2, 3) contains 2 conv units + int c4 = UP_ROUND(conv_dw->conv_.compute_.in_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(C12NUM, c4, NNACL_ERR); + int c12c4 = C12NUM * c4; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4, units, NNACL_ERR); + int c12c4_units = c12c4 * units; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4_units, self->thread_nr_, NNACL_ERR); + int buffer_size = c12c4_units * self->thread_nr_; + + conv_dw->buffer_ = self->env_->Alloc(self->env_->allocator_, buffer_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->buffer_); + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + self->env_->Free(self->env_->allocator_, conv_dw->buffer_); + return ret; + } + + conv_dw->input_ptr_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->input_ptr_); + conv_dw->output_ptr_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDw3x3Run, self, self->thread_nr_); + self->env_->Free(self->env_->allocator_, conv_dw->buffer_); + return ret; +} + +int ConvolutionDepthwise3x3Release(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + ConvBaseRelease(conv); + return NNACL_OK; +} + +KernelBase 
*CreateConvDw3x3(ConvParameter *conv_param) {
+  ConvolutionDepthwise3x3Struct *conv_dw =
+    (ConvolutionDepthwise3x3Struct *)malloc(sizeof(ConvolutionDepthwise3x3Struct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw);
+  memset(conv_dw, 0, sizeof(ConvolutionDepthwise3x3Struct));
+  conv_dw->conv_.pack_weight_ = ConvDw3x3PackWeight;
+  conv_dw->conv_.malloc_weight_bias_ = ConvDw3x3MallocWeightBiasData;
+  conv_dw->conv_.base_.Resize = ConvolutionDepthwise3x3Resize;
+  conv_dw->conv_.base_.Prepare = ConvolutionDepthwise3x3Prepare;
+  conv_dw->conv_.base_.Compute = ConvolutionDepthwise3x3Compute;
+  conv_dw->conv_.base_.Release = ConvolutionDepthwise3x3Release;
+
+  return (KernelBase *)conv_dw;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.h
new file mode 100644
index 00000000..ecbd49ab
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_3x3.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_3X3_H_
+#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_3X3_H_
+
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionDepthwise3x3Struct {
+  ConvolutionBaseStruct conv_;
+  float *buffer_;
+  float *input_ptr_;
+  float *output_ptr_;
+} ConvolutionDepthwise3x3Struct;
+
+KernelBase *CreateConvDw3x3(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_3X3_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.c
new file mode 100644
index 00000000..5b1c6d5e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.c
@@ -0,0 +1,227 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_depthwise_indirect.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +int ConvDwIndirectMallocIndirectBuffer(ConvolutionDepthwiseIndirectStruct *conv_dw) { + ConvComputeParam *compute = &conv_dw->conv_.compute_; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(compute); + + conv_dw->step_w_ = compute->dilation_w_ == 1 ? compute->stride_w_ : compute->kernel_w_; + int step_w_2d = conv_dw->step_w_ * compute->kernel_h_; + conv_dw->step_h_ = (compute->kernel_h_ * compute->kernel_w_) + (compute->out_w_ - 1) * step_w_2d; + int step_h_2d = compute->out_h_ * conv_dw->step_h_; + int buffer_size = compute->out_n_ * step_h_2d; + + ExecEnv *env = conv_dw->conv_.base_.env_; + conv_dw->indirect_buffer_ = (float **)(env->Alloc(env->allocator_, buffer_size * sizeof(float *))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->indirect_buffer_); + return NNACL_OK; +} + +int ConvDwIndirectRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwIndirection(conv_dw->output_ptr_, conv_dw->indirect_buffer_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_dw->zero_ptr_, conv_param, task_id); + return NNACL_OK; +} + +int ConvDwIndirectMallocPackedInput(ConvolutionDepthwiseIndirectStruct *conv_dw) { + int IC_DIV = UP_DIV(conv_dw->conv_.compute_.in_c_, conv_dw->div_flag_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int conv_input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_bhw, conv_dw->div_flag_ * IC_DIV, NNACL_ERR); + int pack_input_size = conv_input_bhw * conv_dw->div_flag_ * IC_DIV; + conv_dw->packed_input_ = + conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + return NNACL_OK; +} + +void ConvDwIndirectPackWeight(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *origin_weight = (conv->base_.train_session_) ? 
weight_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + +#ifdef ENABLE_AVX + PackDepthwiseIndirectWeightC8Fp32(origin_weight, conv->packed_weight_, conv->compute_.kernel_h_, + conv->compute_.kernel_w_, conv->compute_.out_c_); +#else + PackDepthwiseIndirectWeightC4Fp32(origin_weight, conv->packed_weight_, conv->compute_.kernel_h_, + conv->compute_.kernel_w_, conv->compute_.out_c_); +#endif +} + +int ConvDwIndirectMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)conv; + ExecEnv *env = conv->base_.env_; + + int batch_flag = UP_DIV(conv->compute_.out_c_, conv_dw->div_flag_); + int pack_weight_size = conv_dw->div_flag_ * batch_flag * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + // malloc zero ptr + NNACL_CHECK_MALLOC_SIZE(batch_flag * conv_dw->div_flag_ * sizeof(float)); + conv_dw->zero_ptr_ = (float *)env->Alloc(env->allocator_, batch_flag * conv_dw->div_flag_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->zero_ptr_); + memset(conv_dw->zero_ptr_, 0, batch_flag * conv_dw->div_flag_ * sizeof(float)); + + // malloc bias ptr + if (conv->bias_data_ == NULL) { + conv->bias_data_ = env->Alloc(env->allocator_, batch_flag * conv_dw->div_flag_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, batch_flag * conv_dw->div_flag_ * sizeof(float)); + return NNACL_OK; +} + +int ConvolutionDepthwiseIndirectCompute(KernelBase *self) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + void *input_ptr = input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + + if (conv_dw->conv_.compute_.in_c_ % conv_dw->div_flag_ != 0) { + int ret = ConvDwIndirectMallocPackedInput(conv_dw); + if (ret != NNACL_OK) { + return ret; + } +#ifdef ENABLE_AVX + PackNHWCToNHWC8Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_); +#else + PackNHWCToNHWC4Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_); +#endif + } else { + conv_dw->packed_input_ = input_ptr; + } + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + conv_dw->output_ptr_ = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwInitIndirection(conv_dw->indirect_buffer_, conv_dw->packed_input_, conv_dw->zero_ptr_, conv_param, + conv_dw->step_h_, conv_dw->step_w_); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwIndirectRun, self, self->thread_nr_); + + if (conv_dw->conv_.compute_.in_c_ % conv_dw->div_flag_ != 0) { + self->env_->Free(self->env_->allocator_, conv_dw->packed_input_); + } + return ret; +} +int ConvolutionDepthwiseIndirectResize(KernelBase *self) { + 
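+  // Note: the indirection buffer caches per-output-row input pointers, so its size
+  // depends on the current shapes; drop any previous buffer and rebuild it for the
+  // new geometry before relaunching.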
ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + if (conv_dw->indirect_buffer_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->indirect_buffer_); + conv_dw->indirect_buffer_ = NULL; + } + + int ret = ConvBasePrepare(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvDwIndirectMallocIndirectBuffer(conv_dw); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_); + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + return NNACL_OK; +} + +int ConvolutionDepthwiseIndirectPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int batch_flag = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->div_flag_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->div_flag_ * batch_flag, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR); + int pack_weight_size = conv_dw->div_flag_ * batch_flag * conv_dw->conv_.compute_.kernel_hw_; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseIndirectRelease(KernelBase *self) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + if (conv_dw->zero_ptr_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->zero_ptr_); + conv_dw->zero_ptr_ = NULL; + } + if (conv_dw->indirect_buffer_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->indirect_buffer_); + conv_dw->indirect_buffer_ = NULL; + } + ConvBaseRelease(&conv_dw->conv_); + return NNACL_OK; +} + +KernelBase *CreateConvDwIndirect(ConvParameter *conv_param) { + ConvolutionDepthwiseIndirectStruct *conv_dw = + (ConvolutionDepthwiseIndirectStruct *)malloc(sizeof(ConvolutionDepthwiseIndirectStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseIndirectStruct)); + +#ifdef ENABLE_AVX + conv_dw->div_flag_ = C8NUM; +#else + conv_dw->div_flag_ = C4NUM; +#endif + conv_dw->conv_.pack_weight_ = ConvDwIndirectPackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDwIndirectMallocWeightBiasData; + + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseIndirectCompute; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseIndirectResize; + conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseIndirectPrepare; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseIndirectRelease; + + return (KernelBase *)conv_dw; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.h new file mode 100644 index 00000000..0008c772 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_indirect.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_INDIRECT_H_
+#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_INDIRECT_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionDepthwiseIndirectStruct {
+  ConvolutionBaseStruct conv_;
+  int div_flag_;
+  int step_w_;
+  int step_h_;
+  float *zero_ptr_;
+  float *output_ptr_;
+  float *packed_input_;
+  float **indirect_buffer_;
+} ConvolutionDepthwiseIndirectStruct;
+
+KernelBase *CreateConvDwIndirect(ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_INDIRECT_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.c
new file mode 100644
index 00000000..1caf86f9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.c
@@ -0,0 +1,200 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_depthwise_sw.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +int ConvDwSWMallocWeightBiasData(ConvolutionBaseStruct *conv) { + int OC4 = UP_DIV(conv->compute_.out_c_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + int malloc_size = NNACL_MAX(conv->compute_.out_c_, C4NUM * OC4); + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(malloc_size * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, malloc_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, malloc_size * sizeof(float)); + conv->base_.thread_nr_ = NNACL_MIN(conv->base_.thread_nr_, OC4); + return NNACL_OK; +} + +int ConvDwSWInitPackedInputOutput(ConvolutionDepthwiseSWStruct *conv_dw) { + if (conv_dw->conv_.compute_.in_c_ % C4NUM == 0) { + conv_dw->need_align_ = false; + return NNACL_OK; + } + + conv_dw->need_align_ = true; + int IC4 = UP_DIV(conv_dw->conv_.compute_.in_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int conv_input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_bhw, C4NUM * IC4, NNACL_ERR); + int pack_input_size = conv_input_bhw * C4NUM * IC4; + NNACL_CHECK_MALLOC_SIZE(pack_input_size * sizeof(float)); + conv_dw->packed_input_ = + (float *)conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + + int OC4 = UP_DIV(conv_dw->conv_.compute_.out_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.out_n_, conv_dw->conv_.compute_.out_hw_, NNACL_ERR); + int output_bhw = conv_dw->conv_.compute_.out_n_ * conv_dw->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, C4NUM * OC4, NNACL_ERR); + int pack_output_size = output_bhw * C4NUM * OC4; + NNACL_CHECK_MALLOC_SIZE(pack_output_size * sizeof(float)); + conv_dw->packed_output_ = + (float *)conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_output_); + return NNACL_OK; +} + +int ConvDwSWRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwSWFp32(conv_dw->packed_output_, conv_dw->packed_input_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, &conv_dw->sliding_, task_id); + return NNACL_OK; +} + +void ConvDwSWFreePackedInputOutput(ConvolutionDepthwiseSWStruct *conv_dw) { + if (conv_dw->need_align_) { + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_input_); + conv_dw->packed_input_ = NULL; + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_output_); + conv_dw->packed_output_ = NULL; + } +} + +void 
ConvDwSWPackWeight(ConvolutionBaseStruct *conv) {
+  void *origin_weight = (conv->base_.train_session_) ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_;
+  NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
+  PackNCHWToNC4HW4Fp32(origin_weight, conv->packed_weight_, 1, conv->compute_.kernel_hw_, conv->compute_.out_c_);
+}
+
+int ConvolutionDepthwiseSWResize(KernelBase *self) {
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  int ret = ConvBasePrepare(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  ConvParameter *conv_param = (ConvParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+  InitSlidingParamConvDw(&conv_dw->sliding_, conv_param, C4NUM);
+
+  self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_);
+  NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
+  return NNACL_OK;
+}
+
+int ConvolutionDepthwiseSWCompute(KernelBase *self) {
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int ret = ConvDwSWInitPackedInputOutput(conv_dw);
+  if (ret != NNACL_OK) {
+    ConvDwSWFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  ret = ConvBaseRepackWeight(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    ConvDwSWFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  TensorC *input_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  float *input_ptr = (float *)input_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  if (conv_dw->need_align_) {
+    PackNHWCToNHWC4Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_,
+                        conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_);
+  } else {
+    conv_dw->packed_input_ = input_ptr;
+  }
+
+  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+  float *output_ptr = (float *)output_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+  if (!conv_dw->need_align_) {
+    conv_dw->packed_output_ = output_ptr;
+  }
+
+  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwSWRun, self, self->thread_nr_);
+
+  if (conv_dw->need_align_) {
+    PackNHWCXToNHWCFp32(conv_dw->packed_output_, output_ptr, conv_dw->conv_.compute_.out_n_,
+                        conv_dw->conv_.compute_.out_hw_, conv_dw->conv_.compute_.out_c_, C4NUM);
+  }
+
+  ConvDwSWFreePackedInputOutput(conv_dw);
+  return ret;
+}
+
+int ConvolutionDepthwiseSWPrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
+
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  ConvParameter *conv_param = (ConvParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+
+  ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_);
+
+  if (self->train_session_) {
+    int OC4 = UP_DIV(conv_dw->conv_.compute_.out_c_, C4NUM);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM * OC4, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR);
+    int pack_weight_size = C4NUM * OC4 * conv_dw->conv_.compute_.kernel_hw_;
+    self->work_size_ = pack_weight_size * sizeof(float);
+  }
+
+  return ConvBaseInitConvWeightBias(&conv_dw->conv_);
+}
+
+int ConvolutionDepthwiseSWRelease(KernelBase *self) {
+  ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv);
+  ConvBaseRelease(conv);
+  return NNACL_OK;
+}
+
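+// Note: this sliding-window depthwise kernel works on C4NUM-channel blocks; the
+// packed NHWC4 input/output staging buffers are only allocated when the channel
+// count is not already a multiple of C4NUM (see ConvDwSWInitPackedInputOutput).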
+KernelBase *CreateConvDwSW(ConvParameter *conv_param) {
+  ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)malloc(sizeof(ConvolutionDepthwiseSWStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw);
+  memset(conv_dw, 0, sizeof(ConvolutionDepthwiseSWStruct));
+
+  conv_dw->conv_.malloc_weight_bias_ = ConvDwSWMallocWeightBiasData;
+  conv_dw->conv_.pack_weight_ = ConvDwSWPackWeight;
+  conv_dw->conv_.base_.Resize = ConvolutionDepthwiseSWResize;
+  conv_dw->conv_.base_.Compute = ConvolutionDepthwiseSWCompute;
+  conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseSWPrepare;
+  conv_dw->conv_.base_.Release = ConvolutionDepthwiseSWRelease;
+  return (KernelBase *)conv_dw;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.h
new file mode 100644
index 00000000..a7d6819e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_H_
+#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionDepthwiseSWStruct {
+  ConvolutionBaseStruct conv_;
+  SlidingWindowParam sliding_;
+  float *packed_input_;
+  float *packed_output_;
+  bool need_align_;
+} ConvolutionDepthwiseSWStruct;
+
+KernelBase *CreateConvDwSW(ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.c
new file mode 100644
index 00000000..84c830bc
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.c
@@ -0,0 +1,216 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_AVX
+#include "nnacl_c/kernel/convolution_depthwise_sw_avx.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/tensor_c.h"
+
+int ConvDwSWAVXInitPackedInputOutput(ConvolutionDepthwiseSWAVXStruct *conv_dw) {
+  conv_dw->input_need_align_ = (conv_dw->conv_.compute_.in_c_ % conv_dw->oc_tile_ != 0);
+  conv_dw->output_need_align_ = (conv_dw->conv_.compute_.out_c_ % conv_dw->oc_tile_ != 0);
+
+  ExecEnv *env = conv_dw->conv_.base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+
+  if (conv_dw->input_need_align_) {
+    int ic_align = UP_DIV(conv_dw->conv_.compute_.in_c_, conv_dw->oc_tile_);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR);
+    int input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_;
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, conv_dw->oc_tile_ * ic_align, NNACL_ERR);
+    int pack_input_size = input_bhw * conv_dw->oc_tile_ * ic_align;
+    conv_dw->packed_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_);
+  }
+
+  if (conv_dw->output_need_align_) {
+    int oc_align = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.out_n_, conv_dw->conv_.compute_.out_hw_, NNACL_ERR);
+    int output_bhw = conv_dw->conv_.compute_.out_n_ * conv_dw->conv_.compute_.out_hw_;
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_dw->oc_tile_ * oc_align, NNACL_ERR);
+    int pack_output_size = output_bhw * conv_dw->oc_tile_ * oc_align;
+    conv_dw->packed_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_output_);
+  }
+
+  return NNACL_OK;
+}
+
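+// Note: unlike the C4 sliding-window kernel above, input and output alignment are
+// tracked separately here, since in_c_ and out_c_ can straddle the C8 tile
+// boundary independently.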
+void ConvDwSWAVXPackWeight(ConvolutionBaseStruct *conv) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)conv;
+  NNACL_CHECK_NULL_RETURN_VOID(conv_dw);
+
+  int oc_align = UP_DIV(conv->compute_.out_c_, conv_dw->oc_tile_);
+  // pack from the resolved weight pointer so train sessions pick up the live tensor data
+  void *origin_weight = conv->base_.train_session_ ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_;
+  NNACL_CHECK_NULL_RETURN_VOID(origin_weight);
+
+  PackNHWCToNXHWCXFp32(conv->compute_.kernel_h_, conv->compute_.kernel_w_, conv->compute_.out_c_, oc_align, 1,
+                       (float *)conv->packed_weight_, (float *)origin_weight);
+}
+
+int ConvDwSWAVXMallocWeightBiasData(ConvolutionBaseStruct *conv) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)conv;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int oc_align = UP_DIV(conv->compute_.out_c_, conv_dw->oc_tile_);
+  int pack_weight_size = oc_align * conv_dw->oc_tile_ * conv->compute_.kernel_hw_;
+
+  if (!conv->base_.train_session_) {
+    NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float));
+    conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_);
+  }
+
+  if (conv->base_.in_size_ == THREE_TENSOR) {
+    int bias_size = oc_align * conv_dw->oc_tile_;
+    NNACL_CHECK_MALLOC_SIZE(bias_size * sizeof(float));
+    conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, bias_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_);
+    memset(conv->bias_data_, 0, bias_size * sizeof(float));
+  }
+  return NNACL_OK;
+}
+
+int ConvDwSWAvxRun(void *cdata, int task_id, float l, float r) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+
+  DepthwiseSWAvxFp32(conv_dw->packed_output_, conv_dw->packed_input_, (float *)conv_dw->conv_.packed_weight_,
+                     (float *)conv_dw->conv_.bias_data_, conv_param, &conv_dw->sliding_param_, task_id);
+  return NNACL_OK;
+}
+
+void ConvDwSWAVXFreePackedInputOutput(ConvolutionDepthwiseSWAVXStruct *conv_dw) {
+  if (conv_dw->input_need_align_) {
+    conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_input_);
+    conv_dw->packed_input_ = NULL;
+    conv_dw->input_need_align_ = false;
+  }
+  if (conv_dw->output_need_align_) {
+    conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_output_);
+    conv_dw->packed_output_ = NULL;
+    conv_dw->output_need_align_ = false;
+  }
+}
+
+int ConvolutionDepthwiseSWAVXCompute(KernelBase *self) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+
+  int ret = ConvDwSWAVXInitPackedInputOutput(conv_dw);
+  if (ret != NNACL_OK) {
+    ConvDwSWAVXFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  ret = ConvBaseRepackWeight(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    ConvDwSWAVXFreePackedInputOutput(conv_dw);
+    return ret;
+  }
+
+  TensorC *input_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  float *input_ptr = (float *)input_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+
+  if (conv_dw->input_need_align_) {
+    PackNHWCToNHWCXFp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_,
+                        conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_, conv_dw->oc_tile_);
+  } else {
+    conv_dw->packed_input_ = input_ptr;
+  }
+
+  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+  float *output_ptr = (float *)output_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+
+  if (!conv_dw->output_need_align_) {
+    conv_dw->packed_output_ = output_ptr;
+  }
+
+  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwSWAvxRun, self, self->thread_nr_);
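+
+  // Note: results are produced in the channel-tiled layout; they only need to be
+  // unpacked back to plain NHWC when out_c_ is not a multiple of the tile.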
+  if (conv_dw->output_need_align_) {
+    PackNHWCXToNHWCFp32(conv_dw->packed_output_, output_ptr, conv_dw->conv_.compute_.out_n_,
+                        conv_dw->conv_.compute_.out_hw_, conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_);
+  }
+
+  ConvDwSWAVXFreePackedInputOutput(conv_dw);
+  return ret;
+}
+
+int ConvolutionDepthwiseSWAVXPrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
+
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  ConvParameter *conv_param = (ConvParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+
+  conv_dw->oc_tile_ = C8NUM;
+  ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_);
+
+  if (self->train_session_) {
+    int oc_align = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_align * conv_dw->oc_tile_, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR);
+    int pack_weight_size = oc_align * conv_dw->oc_tile_ * conv_dw->conv_.compute_.kernel_hw_;
+    self->work_size_ = pack_weight_size * sizeof(float);
+  }
+
+  return ConvBaseInitConvWeightBias(&conv_dw->conv_);
+}
+
+int ConvolutionDepthwiseSWAVXResize(KernelBase *self) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_dw);
+  ConvParameter *conv_param = (ConvParameter *)self->param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+
+  int ret = ConvBasePrepare(&conv_dw->conv_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  InitSlidingParamConvDw(&conv_dw->sliding_param_, conv_param, conv_dw->oc_tile_);
+  return NNACL_OK;
+}
+
+int ConvolutionDepthwiseSWAVXRelease(KernelBase *self) {
+  ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(conv);
+  ConvBaseRelease(conv);
+  return NNACL_OK;
+}
+
+KernelBase *CreateConvDwSWAVX(ConvParameter *conv_param) {
+  ConvolutionDepthwiseSWAVXStruct *conv_dw =
+    (ConvolutionDepthwiseSWAVXStruct *)malloc(sizeof(ConvolutionDepthwiseSWAVXStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw);
+  memset(conv_dw, 0, sizeof(ConvolutionDepthwiseSWAVXStruct));
+
+  conv_dw->conv_.pack_weight_ = ConvDwSWAVXPackWeight;
+  conv_dw->conv_.malloc_weight_bias_ = ConvDwSWAVXMallocWeightBiasData;
+
+  conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseSWAVXPrepare;
+  conv_dw->conv_.base_.Compute = ConvolutionDepthwiseSWAVXCompute;
+  conv_dw->conv_.base_.Resize = ConvolutionDepthwiseSWAVXResize;
+  conv_dw->conv_.base_.Release = ConvolutionDepthwiseSWAVXRelease;
+  return (KernelBase *)conv_dw;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.h
new file mode 100644
index 00000000..8a76ccbd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_depthwise_sw_avx.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_AVX_H_
+#define NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_AVX_H_
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionDepthwiseSWAVXStruct {
+  ConvolutionBaseStruct conv_;
+  SlidingWindowParam sliding_param_;
+  int oc_tile_;
+  float *packed_input_;
+  float *packed_output_;
+  bool input_need_align_;
+  bool output_need_align_;
+} ConvolutionDepthwiseSWAVXStruct;
+
+KernelBase *CreateConvDwSWAVX(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_DEPTHWISE_SW_AVX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.c
new file mode 100644
index 00000000..482333cd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.c
@@ -0,0 +1,81 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/convolution_im2col.h"
+#include "nnacl_c/kernel/convolution_im2col_base.h"
+#ifdef ENABLE_ARM32
+#include "nnacl_c/kernel/convolution_im2col_arm32.h"
+#endif
+#ifdef ENABLE_ARM64
+#include "nnacl_c/kernel/convolution_im2col_arm64.h"
+#endif
+#ifdef ENABLE_SSE
+#include "nnacl_c/kernel/convolution_im2col_sse.h"
+#endif
+#ifdef ENABLE_AVX
+#include "nnacl_c/kernel/convolution_im2col_avx.h"
+#endif
+#ifdef ENABLE_AVX512
+#include "nnacl_c/intrinsics/ms_simd_cpu_info.h"
+#include "nnacl_c/kernel/convolution_im2col_avx512.h"
+#endif
+
+ConvolutionBaseStruct *CreateConvolutionIm2Col(KernelBase *base, ConvParameter *conv_param) {
+  ConvolutionBaseStruct *kernel = NULL;
+
+#ifdef ENABLE_AVX512
+  FormatC out_format = base->out_[OUTPUT_INDEX]->format_;
+  if (out_format != Format_NC4HW4) {
+    AVX512_HARDWARE_SELF_AWARENESS_BEGIN;
+    kernel = CreateConvIm2ColAVX512(conv_param);
+    if (kernel != NULL) {
+      return kernel;
+    }
+    AVX512_HARDWARE_SELF_AWARENESS_END;
+  }
+#endif
+
+#ifdef ENABLE_AVX
+  kernel = CreateConvIm2ColAVX(conv_param);
+  if (kernel != NULL) {
+    return kernel;
+  }
+#endif
+
+#ifdef ENABLE_SSE
+  kernel = CreateConvIm2ColSSE(conv_param);
+  if (kernel != NULL) {
+    return kernel;
+  }
+#endif
+
+#ifdef ENABLE_ARM64
+  kernel = CreateConvIm2ColARM64(conv_param);
+  if (kernel != NULL) {
+    return kernel;
+  }
+#endif
+
+#ifdef ENABLE_ARM32
+  kernel = CreateConvIm2ColARM32(conv_param);
+  if (kernel != NULL) {
+    return kernel;
+  }
+#endif
+
+  kernel = CreateConvIm2ColBase(conv_param);
+  return kernel;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.h
new file mode 100644
index 00000000..ab115d53
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_H_
+#define NNACL_KERNEL_CONVOLUTION_IM2COL_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+ConvolutionBaseStruct *CreateConvolutionIm2Col(KernelBase *base, ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_IM2COL_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.c
new file mode 100644
index 00000000..44f69f65
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.c
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM32
+#include "nnacl_c/kernel/convolution_im2col_arm32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+
+void ConvIm2ColARM32InitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  conv_im2col->oc_tile_ = C4NUM;
+  conv_im2col->row_tile_ = C12NUM;
+  conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col4Major;
+}
+
+ConvolutionBaseStruct *CreateConvIm2ColARM32(ConvParameter *conv_param) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
+  memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
+
+  conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer;
+  conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
+  conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM32InitGlobalVariable;
+  conv_im2col->conv_.run_impl_ = ConvIm2ColBaseRunImpl;
+  conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
+
+  conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute;
+  conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
+  conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
+  conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
+
+  return (ConvolutionBaseStruct *)conv_im2col;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.h
new file mode 100644
index 00000000..c928273e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm32.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_ARM32_H_
+#define NNACL_KERNEL_CONVOLUTION_IM2COL_ARM32_H_
+
+#ifdef ENABLE_ARM32
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_im2col_base.h"
+
+ConvolutionBaseStruct *CreateConvIm2ColARM32(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_IM2COL_ARM32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.c
new file mode 100644
index 00000000..96bd0c94
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.c
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_ARM64
+#include "nnacl_c/kernel/convolution_im2col_arm64.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
+
+void ConvIm2ColARM64InitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  conv_im2col->oc_tile_ = C8NUM;
+  conv_im2col->row_tile_ = C12NUM;
+  conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major;
+}
+
+int ConvIm2ColARM64RunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
+  float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
+  ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+
+  if (conv->out_format_ != Format_NC4HW4) {
+    if (conv->use_batch_cut_flag_) {
+      ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
+                         (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
+                         conv_param);
+    } else {
+      ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_,
+               conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param);
+    }
+  } else {
+    ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
+                      (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
+                      conv_param);
+  }
+  return NNACL_OK;
+}
+
+ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
+  memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
+
+  conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer;
+  conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
+  conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM64InitGlobalVariable;
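+  // Note: ARM64 overrides run_impl_ so NC4HW4 outputs can go through
+  // ConvFp32OutNC4HW4; the remaining hooks reuse the shared im2col base versions.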
+  conv_im2col->conv_.run_impl_ = ConvIm2ColARM64RunImpl;
+  conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
+
+  conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute;
+  conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
+  conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
+  conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
+
+  return (ConvolutionBaseStruct *)conv_im2col;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.h
new file mode 100644
index 00000000..92288bc7
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_arm64.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_ARM64_H_
+#define NNACL_KERNEL_CONVOLUTION_IM2COL_ARM64_H_
+
+#ifdef ENABLE_ARM64
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_im2col_base.h"
+
+ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param);
+#endif
+#endif  // NNACL_KERNEL_CONVOLUTION_IM2COL_ARM64_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.c
new file mode 100644
index 00000000..2d1efc03
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.c
@@ -0,0 +1,151 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_im2col_avx.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" + +void ConvIm2ColAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C16NUM; + conv_im2col->row_tile_ = C6NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col16Major; +} + +int ConvIm2ColAVXInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + int kernel_chw = conv_im2col->conv_.compute_.kernel_hw_ * conv_im2col->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR); + int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR); + int unit_size = total_kernel_chw * conv_im2col->row_tile_; + + ExecEnv *env = conv_im2col->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + if (conv_im2col->packed_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_); + + if (conv_im2col->col_major_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + conv_im2col->col_major_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_); + + conv_im2col->output_need_align_ = + conv_im2col->conv_.compute_.out_c_ % conv_im2col->oc_tile_ != 0 && conv_im2col->conv_.out_format_ == Format_NC4HW4; + if (conv_im2col->output_need_align_) { + int oc_algin = UP_DIV(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + int output_bhw = conv_im2col->conv_.compute_.out_n_ * conv_im2col->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_im2col->oc_tile_ * oc_algin, NNACL_ERR); + int pack_output_size = output_bhw * conv_im2col->oc_tile_ * oc_algin; + + if (conv_im2col->tmp_output_ != NULL) { + env->Free(env->allocator_, conv_im2col->tmp_output_); + conv_im2col->tmp_output_ = NULL; + } + conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_); + } + return NNACL_OK; +} + +int ConvIm2ColAVXRunImpl(struct ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + float *ori_input_data = conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + + if (conv->out_format_ != Format_NC4HW4) { + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + } else { + ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, 
conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } + return NNACL_OK; +} + +int ConvolutionIm2colAvxCompute(KernelBase *self) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_addr); + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + + if (conv_im2col->output_need_align_) { + PackNC8HW8AlignedToNC8HW8NotAlignedFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_, + conv_im2col->conv_.compute_.out_w_ * conv_im2col->conv_.compute_.out_h_, + conv_im2col->conv_.compute_.out_c_); + } else { + conv_im2col->tmp_output_ = NULL; + } + + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColAVXInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVXInitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColAVXRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvxCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.h new file mode 100644 index 00000000..48e51e66 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_AVX_H_
+#define NNACL_KERNEL_CONVOLUTION_IM2COL_AVX_H_
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_im2col_base.h"
+
+ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param);
+#endif
+#endif // NNACL_KERNEL_CONVOLUTION_IM2COL_AVX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.c
new file mode 100644
index 00000000..9b4a65ea
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.c
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_AVX512
+#include "nnacl_c/kernel/convolution_im2col_avx512.h"
+#include "nnacl_c/fp32/conv_im2col_avx512_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/tensor_c.h"
+
+void ConvIm2ColAVX512InitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  conv_im2col->oc_tile_ = C16NUM;
+  conv_im2col->row_tile_ =
+    MSMIN(UP_DIV(conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.base_.thread_nr_), C150NUM);
+  conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col64Major;
+}
+
+int ConvIm2ColAVX512InitTmpBuffer(struct ConvolutionIm2ColBaseStruct *conv_im2col) {
+  ExecEnv *env = conv_im2col->conv_.base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+  ConvComputeParam *compute = &conv_im2col->conv_.compute_;
+  NNACL_CHECK_NULL_RETURN_ERR(compute);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_hw_, compute->in_c_, NNACL_ERR);
+  int kernel_chw = compute->kernel_hw_ * compute->in_c_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR);
+  int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR);
+  size_t unit_size = total_kernel_chw * conv_im2col->row_tile_;
+
+  if (conv_im2col->packed_input_ != NULL) {
+    env->Free(env->allocator_, conv_im2col->packed_input_);
+    conv_im2col->packed_input_ = NULL;
+  }
+  conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_);
+
+  conv_im2col->output_need_align_ = compute->out_c_ % conv_im2col->oc_tile_ != 0;
+  if (conv_im2col->output_need_align_) {
+    if (conv_im2col->tmp_output_ != NULL) {
+      env->Free(env->allocator_, conv_im2col->tmp_output_);
+      conv_im2col->tmp_output_ = NULL;
+    }
+
+    // avx512 needs to malloc dst aligned to C16NUM
+    int oc_align = UP_ROUND(compute->out_c_, conv_im2col->oc_tile_);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR);
+    int output_bhw = compute->out_n_ * compute->out_hw_;
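+    // tmp_output_ is the C16NUM-padded destination; after the parallel compute
+    // finishes it is copied back to the unpadded user buffer (see
+    // PackNHWCXToNHWCFp32 in ConvolutionIm2colAvx512Compute below)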
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_align, NNACL_ERR);
+    size_t pack_output_size = output_bhw * oc_align;
+
+    conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_);
+  }
+
+  return NNACL_OK;
+}
+
+int ConvIm2ColAVX512RunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
+  if (conv->out_format_ == Format_NC4HW4) {
+    return NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT;
+  }
+
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
+  ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
+  float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
+
+  if (conv->use_batch_cut_flag_) {
+    ConvIm2ColAVX512Fp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
+                                   (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
+                                   conv_im2col->row_tile_);
+  } else {
+    ConvIm2ColAVX512Fp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
+                         (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
+                         conv_im2col->row_tile_);
+  }
+  return NNACL_OK;
+}
+
+int ConvolutionIm2colAvx512Compute(KernelBase *self) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self;
+  int ret = conv_im2col->init_tmp_buffer_(conv_im2col);
+  if (ret != NNACL_OK) {
+    ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
+    return ret;
+  }
+
+  float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_;
+  if (!conv_im2col->output_need_align_) {
+    conv_im2col->tmp_output_ = output_addr;
+  }
+
+  ret = ConvBaseRepackWeight(&conv_im2col->conv_);
+  if (ret != NNACL_OK) {
+    ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
+    return ret;
+  }
+
+  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_);
+
+  if (conv_im2col->output_need_align_) {
+    PackNHWCXToNHWCFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_,
+                        conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.compute_.out_c_,
+                        conv_im2col->oc_tile_);
+  } else {
+    conv_im2col->tmp_output_ = NULL;
+  }
+
+  ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
+  return ret;
+}
+
+ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
+  memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
+
+  conv_im2col->init_tmp_buffer_ = ConvIm2ColAVX512InitTmpBuffer;
+  conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
+  conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVX512InitGlobalVariable;
+  conv_im2col->conv_.run_impl_ = ConvIm2ColAVX512RunImpl;
+  conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
+
+  conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvx512Compute;
+  conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
+  conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
+  conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
+
+  return (ConvolutionBaseStruct *)conv_im2col;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.h
new file mode 100644
index 00000000..8fde7b9f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_avx512.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_AVX512_H_
+#define NNACL_KERNEL_CONVOLUTION_IM2COL_AVX512_H_
+
+#ifdef ENABLE_AVX512
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_im2col_base.h"
+
+ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param);
+#endif
+#endif // NNACL_KERNEL_CONVOLUTION_IM2COL_AVX512_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.c
new file mode 100644
index 00000000..36a9846e
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.c
@@ -0,0 +1,246 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_im2col_base.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/conv_common_fp32.h" + +int ConvIm2ColBaseImpl(void *cdata, int task_id, float l, float r) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cdata); + return conv->run_impl_(conv, task_id); +} + +int ConvIm2ColBaseRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + + float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + return NNACL_OK; +} + +int ConvIm2ColBaseMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + size_t oc_block_num = UP_ROUND(conv->compute_.out_c_, conv_im2col->oc_tile_); + size_t pack_weight_size = oc_block_num * conv->compute_.in_c_ * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(oc_block_num * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, oc_block_num * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, oc_block_num * sizeof(float)); + return NNACL_OK; +} + +int ConvIm2ColBaseUpdateThreadNumProcess(KernelBase *self, int32_t kernel_type, int64_t per_unit_load_num, + int64_t per_unit_store_num, int64_t unit_num) { +#ifdef DYNAMIC_THREAD_DISTRIBUTE + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + if (conv_im2col->conv_.compute_.in_n_ % self->thread_nr_ == 0) { + conv_im2col->conv_.use_batch_cut_flag_ = true; + return NNACL_OK; + } else { + conv_im2col->conv_.use_batch_cut_flag_ = false; + } + + int update_thread = UP_DIV(UP_DIV(conv_im2col->conv_.compute_.out_hw_, conv_im2col->row_tile_), ConvMinBlock); + self->thread_nr_ = NNACL_MIN(self->thread_nr_, update_thread); +#else + self->thread_nr_ = self->thread_nr_ > 0 ? 
self->thread_nr_ : 1; +#endif + return NNACL_OK; +} + +void ConvIm2ColBaseFreeTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + ExecEnv *env = conv_im2col->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (conv_im2col->packed_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + if (conv_im2col->col_major_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + if (conv_im2col->output_need_align_ && conv_im2col->tmp_output_ != NULL) { + env->Free(env->allocator_, conv_im2col->tmp_output_); + conv_im2col->tmp_output_ = NULL; + conv_im2col->output_need_align_ = false; + } +} + +int ConvIm2ColBaseInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)conv_im2col; + TensorC *out_tensor = conv_im2col->conv_.base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv->compute_.kernel_hw_, conv->compute_.in_c_, NNACL_ERR); + int kernel_chw = conv->compute_.kernel_hw_ * conv->compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv->base_.thread_nr_, NNACL_ERR); + int total_kernel_chw = kernel_chw * conv->base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR); + int unit_size = total_kernel_chw * conv_im2col->row_tile_; + + if (conv_im2col->packed_input_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + conv_im2col->packed_input_ = + (float *)conv->base_.env_->Alloc(conv->base_.env_->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_); + + if (conv_im2col->col_major_input_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + conv_im2col->col_major_input_ = + (float *)conv->base_.env_->Alloc(conv->base_.env_->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_); + + return NNACL_OK; +} + +void ConvIm2ColBasePackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = (conv->base_.train_session_) ? 
conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_im2col->row_major_to_col_nmajor_); + conv_im2col->row_major_to_col_nmajor_((float *)origin_weight, (float *)conv->packed_weight_, conv->compute_.out_c_, + conv->compute_.in_c_ * conv->compute_.kernel_hw_); +} + +void ConvIm2ColBaseInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C8NUM; + conv_im2col->row_tile_ = C12NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major; +} + +int ConvolutionIm2colBaseRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +int ConvolutionIm2colBaseCompute(KernelBase *self) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_addr); + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +int ConvolutionIm2colBaseResize(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + + int ret = ConvBaseCheckResizeValid(conv); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(conv); + if (ret != NNACL_OK) { + return ret; + } + + return ConvIm2ColBaseUpdateThreadNumProcess(self, TC_PTYPE(PrimType_Conv2DFusion), 0, 0, 0); +} + +int ConvolutionIm2colBasePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + conv_im2col->conv_.init_global_variable_(&conv_im2col->conv_); + + if (self->train_session_) { + int oc_block_num = UP_ROUND(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + int kernel_chw = conv_im2col->conv_.compute_.in_c_ * conv_im2col->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_block_num, kernel_chw, NNACL_ERR); + int pack_weight_size = oc_block_num * kernel_chw; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_im2col->conv_); +} + +ConvolutionBaseStruct *CreateConvIm2ColBase(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.run_impl_ = 
ConvIm2ColBaseRunImpl;
+  conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
+  conv_im2col->conv_.init_global_variable_ = ConvIm2ColBaseInitGlobalVariable;
+
+  conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute;
+  conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
+  conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
+  conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
+
+  return (ConvolutionBaseStruct *)conv_im2col;
+}
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.h
new file mode 100644
index 00000000..ce8ec9e1
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_base.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_BASE_H_
+#define NNACL_KERNEL_CONVOLUTION_IM2COL_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionIm2ColBaseStruct {
+  ConvolutionBaseStruct conv_;
+  int oc_tile_;
+  int row_tile_;
+
+  float *tmp_output_;
+  float *packed_input_;
+  float *col_major_input_;
+  bool output_need_align_;
+
+  void (*row_major_to_col_nmajor_)(const float *src_ptr, float *dst_ptr, int row, int col);
+  int (*init_tmp_buffer_)(struct ConvolutionIm2ColBaseStruct *conv_im2col);
+} ConvolutionIm2ColBaseStruct;
+
+int ConvIm2ColBaseMallocWeightBiasData(ConvolutionBaseStruct *conv);
+int ConvIm2ColBaseInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col);
+int ConvIm2ColBaseImpl(void *cdata, int task_id, float l, float r);
+void ConvIm2ColBaseFreeTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col);
+void ConvIm2ColBasePackWeight(ConvolutionBaseStruct *conv);
+int ConvIm2ColBaseRunImpl(ConvolutionBaseStruct *conv, int task_id);
+int ConvolutionIm2colBaseCompute(KernelBase *self);
+int ConvolutionIm2colBasePrepare(KernelBase *self);
+int ConvolutionIm2colBaseResize(KernelBase *self);
+int ConvolutionIm2colBaseRelease(KernelBase *self);
+ConvolutionBaseStruct *CreateConvIm2ColBase(ConvParameter *conv_param);
+
+#endif // NNACL_KERNEL_CONVOLUTION_IM2COL_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.c
new file mode 100644
index 00000000..c08d3b09
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.c
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_SSE
+#include "nnacl_c/kernel/convolution_im2col_sse.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/tensor_c.h"
+
+void ConvIm2ColSSEInitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
+  conv_im2col->oc_tile_ = C8NUM;
+  conv_im2col->row_tile_ = C4NUM;
+  conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major;
+}
+
+ConvolutionBaseStruct *CreateConvIm2ColSSE(ConvParameter *conv_param) {
+  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
+  memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
+
+  conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer;
+
+  conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
+  conv_im2col->conv_.run_impl_ = ConvIm2ColBaseRunImpl;
+  conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
+  conv_im2col->conv_.init_global_variable_ = ConvIm2ColSSEInitGlobalVariable;
+
+  conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute;
+  conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
+  conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
+  conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
+
+  return (ConvolutionBaseStruct *)conv_im2col;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.h
new file mode 100644
index 00000000..9762eee9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_im2col_sse.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_IM2COL_SSE_H_
+#define NNACL_KERNEL_CONVOLUTION_IM2COL_SSE_H_
+
+#ifdef ENABLE_SSE
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_im2col_base.h"
+
+ConvolutionBaseStruct *CreateConvIm2ColSSE(ConvParameter *conv_param);
+#endif
+#endif // NNACL_KERNEL_CONVOLUTION_IM2COL_SSE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.c
new file mode 100644
index 00000000..c031575b
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.c
@@ -0,0 +1,227 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
+#include "nnacl_c/kernel/convolution_slidewindow.h"
+#include "nnacl_c/fp32/conv_depthwise_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int ConvSWInitTmpBuffer(ConvolutionSWStruct *conv_sw) {
+  TensorC *input_tensor = conv_sw->conv_.base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  float *input_data = (float *)input_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_data);
+  ConvComputeParam *compute = &conv_sw->conv_.compute_;
+  NNACL_CHECK_NULL_RETURN_ERR(compute);
+
+  if (conv_sw->ic_res_ != 0 && compute->kernel_h_ == 1 && compute->kernel_w_ == 1) {
+    int ic_block_num = UP_DIV(compute->in_c_, conv_sw->in_tile_);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_n_, compute->in_hw_, NNACL_ERR);
+    int input_bhw = compute->in_n_ * conv_sw->conv_.compute_.in_hw_;
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, ic_block_num * conv_sw->in_tile_, NNACL_ERR);
+
+    conv_sw->input_data_ = (float *)conv_sw->conv_.base_.env_->Alloc(
+      conv_sw->conv_.base_.env_->allocator_, input_bhw * ic_block_num * conv_sw->in_tile_ * sizeof(float));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->input_data_);
+
+    PackNHWCToNHWCXFp32(input_data, conv_sw->input_data_, compute->in_n_, compute->in_hw_, compute->in_c_,
+                        conv_sw->oc_tile_);
+  } else {
+    conv_sw->input_data_ = input_data;
+  }
+
+  float *out_data = (float *)conv_sw->conv_.base_.out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(out_data);
+  if (conv_sw->oc_res_ == 0) {  // no need to malloc dst
+    conv_sw->output_data_ = out_data;
+  } else {  // need to malloc dst to align block
+    int oc_block_num = UP_DIV(compute->out_c_, conv_sw->oc_tile_);
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR);
+    int output_bhw = compute->out_n_ * compute->out_hw_;
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_block_num * conv_sw->oc_tile_, NNACL_ERR);
+    conv_sw->output_data_ = (float *)conv_sw->conv_.base_.env_->Alloc(
+      conv_sw->conv_.base_.env_->allocator_, output_bhw * oc_block_num * conv_sw->oc_tile_ *
sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->output_data_); + } + + return NNACL_OK; +} + +void ConvSWFreeTmpBuffer(ConvolutionSWStruct *conv_sw) { + ConvParameter *conv_param = (ConvParameter *)conv_sw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + if (conv_sw->output_data_ != NULL && conv_sw->oc_res_ != 0) { + conv_sw->conv_.base_.env_->Free(conv_sw->conv_.base_.env_->allocator_, conv_sw->output_data_); + conv_sw->output_data_ = NULL; + } + if (conv_sw->input_data_ != NULL && conv_sw->ic_res_ != 0 && conv_param->kernel_w_ == 1 && + conv_param->kernel_h_ == 1) { + conv_sw->conv_.base_.env_->Free(conv_sw->conv_.base_.env_->allocator_, conv_sw->input_data_); + conv_sw->input_data_ = NULL; + } +} + +void ConvSWPackWeight(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(filter_tensor); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + void *origin_weight = (conv->base_.train_session_) ? filter_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackNHWCToNXHWCXFp32(kernel_h, kernel_w, output_channel, oc_block_num, input_channel, (float *)conv->packed_weight_, + (float *)origin_weight); +} + +int ConvSWMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + conv_param->input_channel_ = input_channel; + conv_param->output_channel_ = output_channel; + int kernel_plane = kernel_h * kernel_w; + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + int pack_weight_size = oc_block_num * conv_sw->oc_tile_ * input_channel * kernel_plane; + if (!conv_sw->conv_.base_.train_session_) { + conv_sw->conv_.packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->conv_.packed_weight_); + } + + if (conv_sw->conv_.base_.in_size_ == THREE_TENSOR) { + int malloc_size = oc_block_num * conv_sw->oc_tile_ * sizeof(float); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, malloc_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + memset(conv->bias_data_, 0, oc_block_num * conv_sw->oc_tile_ * sizeof(float)); + } + return NNACL_OK; +} + +int ConvSWImpl(void *cdata, int task_id, float l, float r) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + return conv_sw->conv_.run_impl_(&conv_sw->conv_, task_id); +} + +int ConvolutionSWCompute(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + + int ret = ConvSWInitTmpBuffer(conv_sw); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + 
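+  // make sure the weight has been packed into the block layout used by the
+  // sliding-window kernels before the parallel launch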
+ ret = ConvBaseRepackWeight(&conv_sw->conv_); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvSWImpl, self, self->thread_nr_); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + if (conv_sw->oc_res_ != 0) { + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + float *out_data = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + PackNHWCXToNHWCFp32(conv_sw->output_data_, out_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_, conv_sw->oc_tile_); + } + + ConvSWFreeTmpBuffer(conv_sw); + return NNACL_OK; +} + +int ConvolutionSWRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +int ConvolutionSWResize(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + int ret = ConvBaseCheckResizeValid(&conv_sw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&conv_sw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + InitSlidingParamConv(&conv_sw->sw_param_, conv_param, conv_sw->in_tile_, conv_sw->oc_tile_); + return NNACL_OK; +} + +int ConvolutionSWPrepare(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + + conv_sw->conv_.init_global_variable_(&conv_sw->conv_); + + if (self->train_session_) { + TensorC *filter_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + NNACL_CHECK_FALSE(filter_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_h, kernel_w, NNACL_ERR); + int kernel_hw = kernel_h * kernel_w; + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_channel, kernel_hw, NNACL_ERR); + int kernel_chw = input_channel * kernel_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_block_num * conv_sw->oc_tile_, kernel_chw, NNACL_ERR); + int pack_weight_size = oc_block_num * conv_sw->oc_tile_ * kernel_chw; + + conv_sw->conv_.base_.work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_sw->conv_); +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.h new file mode 100644 index 00000000..a888b9f2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_slidewindow.h @@ -0,0 +1,46 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_SLIDEWINDOW_H_
+#define NNACL_KERNEL_CONVOLUTION_SLIDEWINDOW_H_
+
+#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/matmul_parameter.h"
+
+typedef struct ConvolutionSWStruct {
+  ConvolutionBaseStruct conv_;
+  SlidingWindowParam sw_param_;
+  int oc_tile_;
+  int in_tile_;
+  int oc_res_;
+  int ic_res_;
+  float *output_data_;
+  float *input_data_;
+} ConvolutionSWStruct;
+
+int ConvolutionSWPrepare(KernelBase *self);
+int ConvolutionSWCompute(KernelBase *self);
+int ConvolutionSWResize(KernelBase *self);
+int ConvolutionSWRelease(KernelBase *self);
+void ConvSWPackWeight(ConvolutionBaseStruct *conv);
+int ConvSWMallocWeightBiasData(ConvolutionBaseStruct *conv);
+#endif
+#endif // NNACL_KERNEL_CONVOLUTION_SLIDEWINDOW_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.c
new file mode 100644
index 00000000..c0aa92ed
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.c
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/kernel/convolution_sw_1x1.h"
+#include "nnacl_c/kernel/matmul_base.h"
+#include "nnacl_c/kernel/matmul_create.h"
+
+int MatmulConv1x1Prepare(ConvolutionSW1x1Struct *sw_1x1) {
+  sw_1x1->matmul_->batch_ = 1;
+  sw_1x1->matmul_->a_batch_ = 1;
+  sw_1x1->matmul_->b_batch_ = 1;
+
+  sw_1x1->matmul_->compute_.deep_ = sw_1x1->conv_.compute_.in_c_;
+  sw_1x1->matmul_->compute_.col_ = sw_1x1->conv_.compute_.out_c_;
+  sw_1x1->matmul_->compute_.row_ = sw_1x1->conv_.compute_.in_hw_ * sw_1x1->conv_.compute_.in_n_;
+
+  return sw_1x1->matmul_->base_.Prepare(&sw_1x1->matmul_->base_);
+}
+
+int MatmulConv1x1Resize(ConvolutionSW1x1Struct *sw_1x1) {
+  sw_1x1->matmul_->compute_.deep_ = sw_1x1->conv_.compute_.in_c_;
+  sw_1x1->matmul_->compute_.col_ = sw_1x1->conv_.compute_.out_c_;
+  sw_1x1->matmul_->compute_.row_ = sw_1x1->conv_.compute_.in_hw_ * sw_1x1->conv_.compute_.in_n_;
+
+  MatmulBaseFreeBatchOffset(sw_1x1->matmul_);
+  int ret = MatmulBaseMallocBatchOffset(sw_1x1->matmul_);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  return sw_1x1->matmul_->base_.Resize(&sw_1x1->matmul_->base_);
+}
+
+void UpdateTensorInfo(KernelBase *self, ConvolutionSW1x1Struct *sw_1x1) {
+  sw_1x1->matmul_->base_.in_ = self->in_;
+  sw_1x1->matmul_->base_.in_size_ = self->in_size_;
+  sw_1x1->matmul_->base_.out_ = self->out_;
+  sw_1x1->matmul_->base_.out_size_ = self->out_size_;
+  sw_1x1->matmul_->base_.workspace_ = self->workspace_;
+}
+
+int ConvolutionSW1x1Compute(KernelBase *self) {
+  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
+  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);
+
+  UpdateTensorInfo(self, sw_1x1);
+  return sw_1x1->matmul_->base_.Compute(&sw_1x1->matmul_->base_);
+}
+
+int ConvolutionSW1x1Resize(KernelBase *self) {
+  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
+  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);
+
+  UpdateTensorInfo(self, sw_1x1);
+  return MatmulConv1x1Resize(sw_1x1);
+}
+
+int ConvolutionSW1x1Prepare(KernelBase *self) {
+  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
+  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_);
+
+  sw_1x1->matmul_->matrix_b_.origin_ptr_ = sw_1x1->conv_.origin_weight_;
+  sw_1x1->matmul_->matrix_b_.origin_need_free_ = false;
+  sw_1x1->matmul_->matrix_c_.origin_ptr_ = sw_1x1->conv_.origin_bias_;
+  sw_1x1->matmul_->matrix_c_.origin_need_free_ = false;
+
+  sw_1x1->matmul_->infer_shape_ = sw_1x1->conv_.infershape_done_;
+  sw_1x1->matmul_->base_.train_session_ = self->train_session_;
+  sw_1x1->matmul_->base_.thread_nr_ = self->thread_nr_;
+  sw_1x1->matmul_->base_.env_ = self->env_;
+
+  UpdateTensorInfo(self, sw_1x1);
+  return MatmulConv1x1Prepare(sw_1x1);
+}
+
+int ConvolutionSW1x1Release(KernelBase *self) {
+  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(sw_1x1);
+
+  if (sw_1x1->matmul_ != NULL) {
+    sw_1x1->matmul_->matrix_b_.origin_ptr_ = NULL;
+    sw_1x1->matmul_->matrix_c_.origin_ptr_ = NULL;
+
+    (void)sw_1x1->matmul_->base_.Release(&sw_1x1->matmul_->base_);
+
+    if (sw_1x1->matmul_->base_.param_ != NULL) {
+      free(sw_1x1->matmul_->base_.param_);
+      sw_1x1->matmul_->base_.param_ = NULL;
+    }
+
+    free(sw_1x1->matmul_);
+    sw_1x1->matmul_ = NULL;
+  }
+
+  ConvBaseRelease(&sw_1x1->conv_);
+  return NNACL_OK;
+}
+
+ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const) {
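+  // A 1x1 convolution is just a matmul over the flattened spatial dims
+  // (row = in_n * in_hw, deep = in_c, col = out_c), so this kernel wraps an
+  // internal matmul kernel and forwards Prepare/Resize/Compute to it.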
+  ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)malloc(sizeof(ConvolutionSW1x1Struct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw_1x1);
+  memset(sw_1x1, 0, sizeof(ConvolutionSW1x1Struct));
+
+  sw_1x1->conv_.is_sharing_pack_ = false;
+  sw_1x1->conv_.base_.Compute = ConvolutionSW1x1Compute;
+  sw_1x1->conv_.base_.Resize = ConvolutionSW1x1Resize;
+  sw_1x1->conv_.base_.Prepare = ConvolutionSW1x1Prepare;
+  sw_1x1->conv_.base_.Release = ConvolutionSW1x1Release;
+
+  OpParameter *param = (OpParameter *)malloc(sizeof(MatMulParameter));
+  if (param == NULL) {
+    free(sw_1x1);
+    return NULL;
+  }
+  MatMulParameter *matmul_param = (MatMulParameter *)param;
+  matmul_param->op_parameter_ = conv_param->op_parameter_;
+  matmul_param->act_type_ = conv_param->act_type_;
+  matmul_param->a_transpose_ = false;
+  matmul_param->b_transpose_ = true;
+
+  KernelBase *matmul = CreateMatmulKernel();
+  if (matmul == NULL) {
+    free(sw_1x1);
+    free(param);
+    return NULL;
+  }
+
+  ((MatmulStruct *)matmul)->is_sharing_pack_ = false;
+  ((MatmulStruct *)matmul)->a_const_ = input_const;
+  ((MatmulStruct *)matmul)->b_const_ = weight_const;
+  ((MatmulStruct *)matmul)->base_.param_ = param;
+  sw_1x1->matmul_ = (MatmulStruct *)matmul;
+  return (ConvolutionBaseStruct *)sw_1x1;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.h
new file mode 100644
index 00000000..46804c5d
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_1x1.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_SW_1X1_H_
+#define NNACL_KERNEL_CONVOLUTION_SW_1X1_H_
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/kernel/matmul_struct.h"
+
+typedef struct ConvolutionSW1x1Struct {
+  ConvolutionBaseStruct conv_;
+  MatmulStruct *matmul_;
+} ConvolutionSW1x1Struct;
+
+ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const);
+#endif
+#endif // NNACL_KERNEL_CONVOLUTION_SW_1X1_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.c
new file mode 100644
index 00000000..4a919410
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.c
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_sw_arm64.h" +#include "nnacl_c/kernel/convolution_slidewindow.h" +#include "nnacl_c/fp32/conv_sw_arm64_fp32.h" + +void ConvSWARM64InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + conv_sw->oc_tile_ = C8NUM; + conv_sw->oc_res_ = conv_param->output_channel_ % conv_sw->oc_tile_; +} + +int ConvSWARM64RunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + ConvSWArm64Fp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, conv_sw->output_data_, + task_id, conv_param, &conv_sw->sw_param_); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSWARM64(ConvParameter *conv_param) { + ConvolutionSWStruct *sw = (ConvolutionSWStruct *)malloc(sizeof(ConvolutionSWStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw); + memset(sw, 0, sizeof(ConvolutionSWStruct)); + + sw->conv_.run_impl_ = ConvSWARM64RunImpl; + sw->conv_.init_global_variable_ = ConvSWARM64InitGlobalVariable; + sw->conv_.pack_weight_ = ConvSWPackWeight; + sw->conv_.malloc_weight_bias_ = ConvSWMallocWeightBiasData; + + sw->conv_.base_.Compute = ConvolutionSWCompute; + sw->conv_.base_.Prepare = ConvolutionSWPrepare; + sw->conv_.base_.Release = ConvolutionSWRelease; + sw->conv_.base_.Resize = ConvolutionSWResize; + + return (ConvolutionBaseStruct *)sw; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.h new file mode 100644 index 00000000..e546924c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_arm64.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_SW_ARM64_H_ +#define NNACL_KERNEL_CONVOLLUTION_SW_ARM64_H_ +#ifdef ENABLE_ARM64 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +ConvolutionBaseStruct *CreateConvolutionSWARM64(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_SW_ARM64_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.c new file mode 100644 index 00000000..fbcadfca --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.c @@ -0,0 +1,71 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_sw_avx.h" +#include "nnacl_c/kernel/convolution_slidewindow.h" +#include "nnacl_c/fp32/conv_1x1_avx_fp32.h" +#include "nnacl_c/fp32/conv_sw_avx_fp32.h" + +void ConvSWAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + conv_sw->oc_tile_ = C8NUM; + conv_sw->oc_res_ = conv_param->output_channel_ % conv_sw->oc_tile_; + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + // 1x1 conv is aligned to C8NUM + conv_sw->in_tile_ = C8NUM; + conv_sw->ic_res_ = conv_param->input_channel_ % conv_sw->in_tile_; + } +} + +int ConvSWAVXRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv_param->kernel_w_ == 1 && conv_param->kernel_h_ == 1) { + Conv1x1SWAVXFp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_sw->output_data_, task_id, conv_param, &conv_sw->sw_param_); + } else { + ConvSWAVXFp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, conv_sw->output_data_, + task_id, conv_param, &conv_sw->sw_param_); + } + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSWAVX(ConvParameter *conv_param) { + ConvolutionSWStruct *sw = (ConvolutionSWStruct *)malloc(sizeof(ConvolutionSWStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw); + memset(sw, 0, sizeof(ConvolutionSWStruct)); + + sw->conv_.run_impl_ = ConvSWAVXRunImpl; + sw->conv_.init_global_variable_ = ConvSWAVXInitGlobalVariable; + sw->conv_.pack_weight_ = ConvSWPackWeight; + sw->conv_.malloc_weight_bias_ = ConvSWMallocWeightBiasData; + + sw->conv_.base_.Compute = ConvolutionSWCompute; + sw->conv_.base_.Prepare = ConvolutionSWPrepare; + sw->conv_.base_.Release = ConvolutionSWRelease; + sw->conv_.base_.Resize = 
+
+  return (ConvolutionBaseStruct *)sw;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.h
new file mode 100644
index 00000000..c2d47268
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_sw_avx.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_SW_AVX_H_
+#define NNACL_KERNEL_CONVOLUTION_SW_AVX_H_
+#ifdef ENABLE_AVX
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+ConvolutionBaseStruct *CreateConvolutionSWAVX(ConvParameter *conv_param);
+#endif
+#endif  // NNACL_KERNEL_CONVOLUTION_SW_AVX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.c
new file mode 100644
index 00000000..d7e464dd
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.c
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_winograd.h" +#include "nnacl_c/kernel/convolution_winograd_base.h" +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/convolution_winograd_avx.h" +#endif +#ifdef ENABLE_SSE +#include "nnacl_c/kernel/convolution_winograd_sse.h" +#endif +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/convolution_winograd_arm64.h" +#endif +#ifdef ENABLE_ARM32 +#include "nnacl_c/kernel/convolution_winograd_arm32.h" +#endif + +ConvolutionWinogradBaseStruct *SelectConvolutionWinograd(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *kernel = NULL; + +#ifdef ENABLE_AVX + kernel = CreateConvWinogradAVX(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_SSE + kernel = CreateConvWinogradSSE(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM64 + kernel = CreateConvWinogradARM64(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM32 + kernel = CreateConvWinogradARM32(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + + kernel = CreateConvWinogradBase(conv_param); + return kernel; +} + +ConvolutionBaseStruct *CreateConvolutionWinograd(ConvParameter *conv_param, int out_unit) { + ConvolutionWinogradBaseStruct *kernel = SelectConvolutionWinograd(conv_param); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + + kernel->output_unit_ = out_unit; + kernel->conv_.malloc_weight_bias_ = ConvWinoBaseMallocWeightBiasData; + kernel->conv_.run_impl_ = ConvWinoBaseRunImpl; + kernel->conv_.pack_weight_ = ConvWinoBasePackWeight; + return (ConvolutionBaseStruct *)kernel; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.h new file mode 100644 index 00000000..23c5e1c8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+
+typedef struct ConvolutionWinogradStruct {
+  ConvolutionBaseStruct conv_;
+} ConvolutionWinogradStruct;
+
+ConvolutionBaseStruct *CreateConvolutionWinograd(ConvParameter *conv_param, int out_unit);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.c
new file mode 100644
index 00000000..f22088a8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.c
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_ARM32
+#include "nnacl_c/kernel/convolution_winograd_arm32.h"
+
+void ConvWinoARM32InitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv;
+  winograd->oc_block_ = C8NUM;
+  winograd->tmp_data_tile_ = C4NUM;
+  winograd->tile_num_ = C12NUM;
+}
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradARM32(ConvParameter *conv_param) {
+  ConvolutionWinogradBaseStruct *winograd =
+    (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd);
+  memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct));
+
+  winograd->config_input_output_ = ConvWinoBaseConfigInputOutput;
+  winograd->conv_.init_global_variable_ = ConvWinoARM32InitGlobalVariable;
+
+  winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare;
+  winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize;
+  winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease;
+  winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute;
+  return winograd;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.h
new file mode 100644
index 00000000..21b32c4a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm32.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM32_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM32_H_
+
+#ifdef ENABLE_ARM32
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradARM32(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM32_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.c
new file mode 100644
index 00000000..0594f2da
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.c
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_ARM64
+#include "nnacl_c/kernel/convolution_winograd_arm64.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+void ConvWinoARM64InitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv;
+  winograd->oc_block_ = C8NUM;
+  winograd->tmp_data_tile_ = C4NUM;
+  winograd->tile_num_ = C12NUM;
+}
+
+int ConvWinoARM64ConfigInputOutput(ConvolutionWinogradBaseStruct *winograd) {
+  winograd->transfer_functions_.in_func_ = GetInputTransFunc(winograd->input_unit_);
+  NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_);
+
+  winograd->transfer_functions_.in_step_func_ = GetInputTransStepFunc(winograd->input_unit_);
+  NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_step_func_);
+
+  winograd->transfer_functions_.in_pack_func_ = GetInputTransPackFunc(winograd->input_unit_);
+  NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_pack_func_);
+
+  ActType act_type = ((ConvParameter *)winograd->conv_.base_.param_)->act_type_;
+  winograd->transfer_functions_.out_func_ = GetOutputTransFunc(winograd->input_unit_, winograd->output_unit_, act_type);
+  NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.out_func_);
+
+  return NNACL_OK;
+}
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradARM64(ConvParameter *conv_param) {
+  ConvolutionWinogradBaseStruct *winograd =
+    (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd);
+  memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct));
+
+  winograd->config_input_output_ = ConvWinoARM64ConfigInputOutput;
+  winograd->conv_.init_global_variable_ = ConvWinoARM64InitGlobalVariable;
+
+  winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare;
+  winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize;
+  winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease;
+  winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute;
+  return winograd;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.h
new file mode 100644
index 00000000..d2e98d4f
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_arm64.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM64_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM64_H_
+
+#ifdef ENABLE_ARM64
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradARM64(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_ARM64_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.c
new file mode 100644
index 00000000..30242ae4
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.c
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/kernel/convolution_winograd_avx.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+void ConvWinoAVXInitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv;
+  winograd->oc_block_ = C16NUM;
+  winograd->tmp_data_tile_ = C8NUM;
+  winograd->tile_num_ = C12NUM;
+}
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradAVX(ConvParameter *conv_param) {
+  ConvolutionWinogradBaseStruct *winograd =
+    (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd);
+  memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct));
+
+  winograd->config_input_output_ = ConvWinoBaseConfigInputOutput;
+  winograd->conv_.init_global_variable_ = ConvWinoAVXInitGlobalVariable;
+
+  winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare;
+  winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize;
+  winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease;
+  winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute;
+  return winograd;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.h
new file mode 100644
index 00000000..baa83182
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_avx.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_AVX_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_AVX_H_
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradAVX(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_AVX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.c
new file mode 100644
index 00000000..f71e7fce
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.c
@@ -0,0 +1,320 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/convolution_winograd_base.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/fp32/winograd_transform.h" +#include "nnacl_c/fp32/conv_winograd_fp32.h" + +int ConvWinoBaseMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd); + + // set data + size_t trans_matrix_data_size = winograd->input_unit_ * winograd->input_unit_ * conv->compute_.in_c_ * + UP_ROUND(conv->compute_.out_c_, winograd->oc_block_) * sizeof(float); + if (!conv->base_.train_session_) { + if (conv->packed_weight_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(trans_matrix_data_size); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, trans_matrix_data_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + } + + float matrix_a[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_at[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_b[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_bt[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float coef = 1.0f; + if (winograd->input_unit_ == CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE) { + coef = 0.5f; + } + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, winograd->matrix_g_, winograd->matrix_gt_, coef, + winograd->output_unit_, winograd->kernel_unit_); + if (ret != NNACL_OK) { + return ret; + } + + // init bias + size_t new_bias_size = UP_ROUND(conv->compute_.out_c_, C4NUM) * sizeof(float); + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(new_bias_size); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, new_bias_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, new_bias_size); + return NNACL_OK; +} + +void ConvWinoBaseFreeTmpBuffer(ConvolutionWinogradBaseStruct *winograd) { + ExecEnv *env = winograd->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (winograd->trans_input_ != NULL) { + env->Free(env->allocator_, winograd->trans_input_); + winograd->trans_input_ = NULL; + } + if (winograd->tmp_data_ != NULL) { + env->Free(env->allocator_, winograd->tmp_data_); + winograd->tmp_data_ = NULL; + } + if (winograd->gemm_out_ != NULL) { + env->Free(env->allocator_, winograd->gemm_out_); + winograd->gemm_out_ = NULL; + } + if (winograd->col_buffer_ != NULL) { + env->Free(env->allocator_, winograd->col_buffer_); + winograd->col_buffer_ = NULL; + } + if (winograd->opt_input_trans_ != NULL) { + env->Free(env->allocator_, winograd->opt_input_trans_); + winograd->opt_input_trans_ = NULL; + } +} + +void ConvWinoBaseInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} + +int ConvWinoBaseWinogradFilterTransform(ConvolutionWinogradBaseStruct *winograd, const float *weight_data) { + NNACL_CHECK_ZERO_RETURN_ERR(winograd->oc_block_); + return WinogradWeightTransform(weight_data, (float *)winograd->conv_.packed_weight_, winograd->matrix_g_, + winograd->matrix_gt_, winograd->oc_block_, winograd->input_unit_, + winograd->kernel_unit_, winograd->conv_.compute_.in_c_, + winograd->conv_.compute_.out_c_, true); +} + +void ConvWinoBasePackWeight(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(winograd); + TensorC *weight_tensor = 
conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *origin_weight = (conv->base_.train_session_) ? weight_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + ConvWinoBaseWinogradFilterTransform(winograd, (float *)origin_weight); +} + +int ConvolutionWinogradBasePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + winograd->conv_.init_global_variable_(&winograd->conv_); + + winograd->kernel_unit_ = winograd->conv_.compute_.kernel_h_; + winograd->input_unit_ = winograd->output_unit_ + winograd->kernel_unit_ - 1; + + if (self->train_session_) { + TensorC *filter_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + NNACL_CHECK_FALSE(filter_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int input_plane = winograd->input_unit_ * winograd->input_unit_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_plane, winograd->conv_.compute_.in_c_, NNACL_ERR); + int in_chw = input_plane * winograd->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(in_chw, UP_ROUND(winograd->conv_.compute_.out_c_, winograd->oc_block_), NNACL_ERR); + int trans_matrix_data_size = + in_chw * UP_ROUND(winograd->conv_.compute_.out_c_, winograd->oc_block_) * sizeof(float); + self->work_size_ = trans_matrix_data_size; + } + + return ConvBaseInitConvWeightBias(&winograd->conv_); +} + +int ConvoWinoBaseUpdateThreadNumProcess(ConvolutionWinogradBaseStruct *winograd) { + if (winograd->conv_.compute_.in_n_ % winograd->conv_.base_.thread_nr_ == 0) { + winograd->conv_.use_batch_cut_flag_ = true; + return NNACL_OK; + } else { + winograd->conv_.use_batch_cut_flag_ = false; + } + + int update_thread = UP_DIV(UP_DIV(winograd->conv_.compute_.out_hw_, C12NUM), ConvMinBlock); + winograd->conv_.base_.thread_nr_ = NNACL_MIN(update_thread, winograd->conv_.base_.thread_nr_); + return NNACL_OK; +} + +int ConvoWinoBaseUpdateThread(ConvolutionWinogradBaseStruct *winograd) { +#ifdef DYNAMIC_THREAD_DISTRIBUTE + ConvoWinoBaseUpdateThreadNumProcess(winograd); +#else + KernelBase *base = &winograd->conv_.base_; + base->thread_nr_ = base->UpdateThread(TC_PTYPE(PrimType_Conv2DFusion), 0, 0, 0, base->thread_nr_); +#endif + return NNACL_OK; +} + +int ConvWinoBaseConfigInputOutput(ConvolutionWinogradBaseStruct *winograd) { + winograd->transfer_functions_.in_func_ = GetInputTransFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + ConvParameter *conv_param = (ConvParameter *)winograd->conv_.base_.param_; + winograd->transfer_functions_.out_func_ = + GetOutputTransFunc(winograd->input_unit_, winograd->output_unit_, conv_param->act_type_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.out_func_); + + return NNACL_OK; +} + +int ConvoWinoBaseInitTmpBuffer(ConvolutionWinogradBaseStruct *winograd) { + ExecEnv *env = winograd->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + int thread_input_plane = winograd->conv_.base_.thread_nr_ * winograd->input_unit_ * winograd->input_unit_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(winograd->tile_num_, thread_input_plane, NNACL_ERR); + int total_thread_input_plane = winograd->tile_num_ * thread_input_plane; + size_t tile_buffer_size = total_thread_input_plane * 
winograd->conv_.compute_.in_c_ * sizeof(float); + winograd->trans_input_ = (float *)env->Alloc(env->allocator_, tile_buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->trans_input_); + + int oc8 = UP_ROUND(winograd->conv_.compute_.out_c_, C8NUM); + winograd->gemm_out_ = env->Alloc(env->allocator_, total_thread_input_plane * oc8 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->gemm_out_); + + winograd->tmp_data_ = env->Alloc(env->allocator_, winograd->tmp_data_tile_ * thread_input_plane * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->tmp_data_); + + winograd->col_buffer_ = env->Alloc(env->allocator_, winograd->conv_.base_.thread_nr_ * winograd->tile_num_ * + winograd->conv_.compute_.in_c_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->col_buffer_); + + int tile = UP_ROUND(winograd->conv_.compute_.in_c_, winograd->tmp_data_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_input_plane, tile, NNACL_ERR); + winograd->opt_input_trans_ = env->Alloc(env->allocator_, total_thread_input_plane * tile * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->opt_input_trans_); + + winograd->tmp_buffer_address_list_[Index0] = winograd->trans_input_; + winograd->tmp_buffer_address_list_[Index1] = winograd->gemm_out_; + winograd->tmp_buffer_address_list_[Index2] = winograd->tmp_data_; + winograd->tmp_buffer_address_list_[Index3] = winograd->col_buffer_; + winograd->tmp_buffer_address_list_[Index4] = winograd->opt_input_trans_; + return NNACL_OK; +} + +int ConvWinoBaseRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + TensorC *input_tensor = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_data = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + + TensorC *output_tensor = conv->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *output_data = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + if (conv->use_batch_cut_flag_) { + ConvWinogardFp32CutByBatch(input_data, (float *)conv->packed_weight_, (float *)conv->bias_data_, output_data, + winograd->tmp_buffer_address_list_, task_id, conv_param, winograd->transfer_functions_); + } else { + ConvWinogardFp32(input_data, (float *)conv->packed_weight_, (float *)conv->bias_data_, output_data, + winograd->tmp_buffer_address_list_, task_id, conv_param, winograd->transfer_functions_); + } + + return NNACL_OK; +} + +int ConvWinoImpl(void *cdata, int task_id, float l, float r) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv); + return conv->run_impl_(conv, task_id); +} + +void ConvWinoBaseUpdateParam(ConvParameter *param, ConvolutionWinogradBaseStruct *winograd) { + param->input_unit_ = winograd->input_unit_; + param->output_unit_ = winograd->output_unit_; +} + +int ConvolutionWinogradBaseResize(KernelBase *self) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + int ret = ConvBaseCheckResizeValid(&winograd->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&winograd->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvoWinoBaseUpdateThread(winograd); + if (ret != 
NNACL_OK) { + return ret; + } + + ret = winograd->config_input_output_(winograd); + if (ret != NNACL_OK) { + return ret; + } + + ConvWinoBaseUpdateParam((ConvParameter *)self->param_, winograd); + return NNACL_OK; +} + +int ConvolutionWinogradBaseCompute(KernelBase *self) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + int ret = ConvoWinoBaseInitTmpBuffer(winograd); + if (ret != NNACL_OK) { + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; + } + + ret = ConvBaseRepackWeight(&winograd->conv_); + if (ret != NNACL_OK) { + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvWinoImpl, self, self->thread_nr_); + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; +} + +int ConvolutionWinogradBaseRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +ConvolutionWinogradBaseStruct *CreateConvWinogradBase(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); + + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoBaseInitGlobalVariable; + + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return (ConvolutionWinogradBaseStruct *)winograd; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.h new file mode 100644 index 00000000..85ffca6b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_base.h @@ -0,0 +1,65 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_BASE_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_BASE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_base.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+
+#define CONVOLUTION_WINOGRAD_MATRIX_SIZE 64
+#define CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE 5
+#define CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE 8
+
+typedef float *TmpBufferAddress;
+
+typedef struct ConvolutionWinogradBaseStruct {
+  ConvolutionBaseStruct conv_;
+
+  int kernel_unit_;
+  int input_unit_;
+  int output_unit_;
+  int oc_block_;
+  int tile_num_;
+  int tmp_data_tile_;
+  float *tmp_data_;
+  float *trans_input_;
+  float *gemm_out_;
+  float *col_buffer_;
+  float *opt_input_trans_;
+  float matrix_g_[CONVOLUTION_WINOGRAD_MATRIX_SIZE];
+  float matrix_gt_[CONVOLUTION_WINOGRAD_MATRIX_SIZE];
+  TmpBufferAddress tmp_buffer_address_list_[CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE];
+  TransFuncList transfer_functions_;
+
+  int (*config_input_output_)(struct ConvolutionWinogradBaseStruct *winograd);
+} ConvolutionWinogradBaseStruct;
+
+void ConvWinoBasePackWeight(ConvolutionBaseStruct *conv);
+int ConvWinoBaseConfigInputOutput(ConvolutionWinogradBaseStruct *winograd);
+int ConvWinoBaseRunImpl(ConvolutionBaseStruct *conv, int task_id);
+int ConvWinoBaseMallocWeightBiasData(ConvolutionBaseStruct *conv);
+int ConvolutionWinogradBasePrepare(KernelBase *self);
+int ConvolutionWinogradBaseResize(KernelBase *self);
+int ConvolutionWinogradBaseRelease(KernelBase *self);
+int ConvolutionWinogradBaseCompute(KernelBase *self);
+ConvolutionWinogradBaseStruct *CreateConvWinogradBase(ConvParameter *conv_param);
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_BASE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.c
new file mode 100644
index 00000000..d91dbc50
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.c
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_SSE
+#include "nnacl_c/kernel/convolution_winograd_sse.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+void ConvWinoSSEInitGlobalVariable(ConvolutionBaseStruct *conv) {
+  ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv;
+  winograd->oc_block_ = C8NUM;
+  winograd->tmp_data_tile_ = C4NUM;
+  winograd->tile_num_ = C12NUM;
+}
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradSSE(ConvParameter *conv_param) {
+  ConvolutionWinogradBaseStruct *winograd =
+    (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd);
+  memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct));
+
+  winograd->config_input_output_ = ConvWinoBaseConfigInputOutput;
+  winograd->conv_.init_global_variable_ = ConvWinoSSEInitGlobalVariable;
+
+  winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare;
+  winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize;
+  winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease;
+  winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute;
+
+  return winograd;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.h
new file mode 100644
index 00000000..82755a52
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/convolution_winograd_sse.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_CONVOLUTION_WINOGRAD_SSE_H_
+#define NNACL_KERNEL_CONVOLUTION_WINOGRAD_SSE_H_
+
+#ifdef ENABLE_SSE
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/kernel/convolution_winograd_base.h"
+
+ConvolutionWinogradBaseStruct *CreateConvWinogradSSE(ConvParameter *conv_param);
+#endif
+
+#endif  // NNACL_KERNEL_CONVOLUTION_WINOGRAD_SSE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.c
new file mode 100644
index 00000000..16244cb8
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.c
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/crop.h" +#include "nnacl_c/base/crop_base.h" +#include "nnacl_c/fp32/crop_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/crop_fp16.h" +#endif + +int CropLaunch(void *cdata, int task_id, float l, float r) { + CropStruct *crop = (CropStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(crop); + + TensorC *in = crop->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + TensorC *out = crop->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out); + +#ifdef ENABLE_FP16 + if (in->data_type_ == kNumberTypeFloat16) { + Fp16Crop((float16_t *)in->data_, (float16_t *)out->data_, in->shape_, out->shape_, crop->in_offset_, + in->shape_size_, task_id, crop->base_.thread_nr_); + return NNACL_OK; + } +#endif + + CropParameter *crop_param = (CropParameter *)crop->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(crop_param); + Crop4D((float *)in->data_, (float *)out->data_, in->shape_, out->shape_, crop_param, task_id, crop->base_.thread_nr_); + return NNACL_OK; +} + +int CropResize(struct KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_FALSE(out_tensor->shape_size_ <= Num1, NNACL_OUTPUT_TENSOR_ERROR); + + CropStruct *crop = (CropStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(crop); + CropParameter *crop_param = (CropParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(crop_param); + + return CropPadOffset(in_tensor->shape_size_, crop_param, crop->in_offset_); +} + +int CropCompute(struct KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + CropParameter *crop_param = (CropParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(crop_param); + + if (in_tensor->data_type_ != kNumberTypeFloat16 && out_tensor->shape_[Index1] < self->thread_nr_) { + float *input_data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + float *output_data = (float *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + Crop4DNoParallel(input_data, output_data, in_tensor->shape_, out_tensor->shape_, crop_param); + return NNACL_OK; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, CropLaunch, self, self->thread_nr_); +} + +KernelBase *CreateCrop(OpParameter *param, int data_type) { + CropStruct *crop = (CropStruct *)malloc(sizeof(CropStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(crop); + memset(crop, 0, sizeof(CropStruct)); + crop->base_.Prepare = DefaultPrepare1In1Out; + crop->base_.Resize = CropResize; + crop->base_.Release = DefaultRelease; + crop->base_.Compute = CropCompute; + return (KernelBase *)crop; +} + +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeInt32, CreateCrop) +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat32, CreateCrop) +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat16, CreateCrop) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.h new file mode 100644 index 00000000..26408dd7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CROP_H_ +#define NNACL_KERNEL_CROP_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct CropStruct { + KernelBase base_; + int64_t in_offset_[COMM_SHAPE_SIZE]; +} CropStruct; + +KernelBase *CreateCrop(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CROP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.c new file mode 100644 index 00000000..0c0054b3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.c @@ -0,0 +1,190 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/crop_and_resize.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/resize_fp32.h" +#include "nnacl_c/tensor_c_utils.h" + +int CropAndResizeMallocTmpBuffer(CropAndResizeStruct *crop_and_resize) { + TensorC *input_tensor = crop_and_resize->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *output_tensor = crop_and_resize->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + ExecEnv *env = crop_and_resize->base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + // Malloc buffer to save coordinate. + // For mode CROP_AND_RESIZE, different output batches require different cache coordinates. 
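+  // Note on sizing (added commentary, inferred from the allocations below): the vertical caches
+  // (y_bottoms_/y_tops_/y_bottom_weights_) hold batch_ * new_height_ entries and the horizontal caches
+  // (x_lefts_/x_rights_/x_left_weights_) hold batch_ * new_width_ entries, i.e. one mapped source coordinate
+  // and one bilinear weight per output row and column of every box.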
+ crop_and_resize->batch_ = NNACLGetBatch(output_tensor); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_height_, crop_and_resize->batch_, NNACL_ERR); + int height_size = crop_and_resize->new_height_ * crop_and_resize->batch_; + NNACL_CHECK_MALLOC_SIZE(height_size); + crop_and_resize->y_bottoms_ = (int *)env->Alloc(env->allocator_, height_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_bottoms_); + crop_and_resize->y_tops_ = (int *)env->Alloc(env->allocator_, height_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_tops_); + crop_and_resize->y_bottom_weights_ = (float *)env->Alloc(env->allocator_, height_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_bottom_weights_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_width_, crop_and_resize->batch_, NNACL_ERR); + int width_size = crop_and_resize->new_width_ * crop_and_resize->batch_; + NNACL_CHECK_MALLOC_SIZE(width_size); + crop_and_resize->x_lefts_ = (int *)env->Alloc(env->allocator_, width_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_lefts_); + crop_and_resize->x_rights_ = (int *)env->Alloc(env->allocator_, width_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_rights_); + crop_and_resize->x_left_weights_ = (float *)env->Alloc(env->allocator_, width_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_left_weights_); + + int c = NNACLGetChannel(input_tensor); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_width_, c, NNACL_ERR); + int new_wc = crop_and_resize->new_width_ * c; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(new_wc, crop_and_resize->mapped_point_num_, NNACL_ERR); + int total_point_num = new_wc * crop_and_resize->mapped_point_num_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_point_num, crop_and_resize->base_.thread_nr_, NNACL_ERR); + int line_buffer_size = total_point_num * crop_and_resize->base_.thread_nr_ * sizeof(float); + crop_and_resize->line_buffer_ = (float *)env->Alloc(env->allocator_, line_buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->line_buffer_); + return NNACL_OK; +} + +void CropAndResizeFreeTmpBuffer(CropAndResizeStruct *crop_and_resize) { + ExecEnv *env = crop_and_resize->base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + env->Free(env->allocator_, crop_and_resize->y_bottoms_); + env->Free(env->allocator_, crop_and_resize->y_tops_); + env->Free(env->allocator_, crop_and_resize->y_bottom_weights_); + env->Free(env->allocator_, crop_and_resize->x_lefts_); + env->Free(env->allocator_, crop_and_resize->x_rights_); + env->Free(env->allocator_, crop_and_resize->x_left_weights_); + env->Free(env->allocator_, crop_and_resize->line_buffer_); + crop_and_resize->y_bottoms_ = NULL; + crop_and_resize->y_tops_ = NULL; + crop_and_resize->y_bottom_weights_ = NULL; + crop_and_resize->x_lefts_ = NULL; + crop_and_resize->x_rights_ = NULL; + crop_and_resize->x_left_weights_ = NULL; + crop_and_resize->line_buffer_ = NULL; +} + +int CropAndResizeImpl(void *cdata, int task_id, float l, float r) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize); + + TensorC *input = crop_and_resize->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *boxes = crop_and_resize->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(boxes); + TensorC *box_idx = crop_and_resize->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(box_idx); + TensorC *output = 
crop_and_resize->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+
+  int unit = UP_DIV(crop_and_resize->new_height_, crop_and_resize->base_.thread_nr_);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(unit, task_id, NNACL_ERR);
+  int h_begin = unit * task_id;
+  int h_end = MSMIN(h_begin + unit, crop_and_resize->new_height_);
+  if (h_end <= h_begin) {
+    return NNACL_OK;
+  }
+
+  float extrapolation_value = ((CropAndResizeParameter *)crop_and_resize->base_.param_)->extrapolation_value_;
+  int c = input->shape_[kNHWC_C];
+  float *line0 = crop_and_resize->line_buffer_ + crop_and_resize->new_width_ * c * 2 * task_id;
+  float *line1 = line0 + crop_and_resize->new_width_ * c;
+
+  return CropAndResizeBilinear((float *)input->data_, (float *)output->data_, (int32_t *)box_idx->data_,
+                               (float *)boxes->data_, extrapolation_value, input->shape_, output->shape_,
+                               crop_and_resize->y_bottoms_, crop_and_resize->y_tops_, crop_and_resize->x_lefts_,
+                               crop_and_resize->x_rights_, crop_and_resize->y_bottom_weights_,
+                               crop_and_resize->x_left_weights_, line0, line1, h_begin, h_end);
+}
+
+int CropAndResizeCompute(struct KernelBase *self) {
+  CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize);
+
+  // In Prepare() stage, in_tensor[0] may be of fp16 data type in fp16 mode, so move type checks here.
+  TensorC *input_tensor = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  TensorC *boxes_tensor = self->in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(boxes_tensor);
+  TensorC *boxidx_tensor = self->in_[THIRD_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(boxidx_tensor);
+  TensorC *output_tensor = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+
+  int ret = CropAndResizeMallocTmpBuffer(crop_and_resize);
+  if (ret != NNACL_OK) {
+    CropAndResizeFreeTmpBuffer(crop_and_resize);
+    return ret;
+  }
+
+  float *boxes = (float *)boxes_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(boxes);
+  int32_t *box_idx = (int32_t *)boxidx_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(box_idx);
+
+  if (CheckCropAndResizeBoxIdx(box_idx, boxes_tensor->shape_[Index0], NNACLGetBatch(input_tensor)) != NNACL_OK) {
+    CropAndResizeFreeTmpBuffer(crop_and_resize);
+    return NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID;
+  }
+
+  ret = PrepareCropAndResizeBilinear(input_tensor->shape_, boxes, box_idx, output_tensor->shape_,
+                                     crop_and_resize->y_bottoms_, crop_and_resize->y_tops_, crop_and_resize->x_lefts_,
+                                     crop_and_resize->x_rights_, crop_and_resize->y_bottom_weights_,
+                                     crop_and_resize->x_left_weights_);
+  if (ret != NNACL_OK) {
+    CropAndResizeFreeTmpBuffer(crop_and_resize);
+    return ret;
+  }
+
+  int error_code = self->env_->ParallelLaunch(self->env_->thread_pool_, CropAndResizeImpl, self, self->thread_nr_);
+  CropAndResizeFreeTmpBuffer(crop_and_resize);
+  return error_code;
+}
+
+int CropAndResizeResize(KernelBase *self) {
+  CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize);
+  TensorC *output = self->out_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+  crop_and_resize->new_height_ = output->shape_[Index1];
+  crop_and_resize->new_width_ = output->shape_[Index2];
+  return NNACL_OK;
+}
+
+int CropAndResizePrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
+  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]);
+  NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]);
+  return NNACL_OK;
+}
+
+KernelBase *CreateCropAndResize(OpParameter
*param, int data_type) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)malloc(sizeof(CropAndResizeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(crop_and_resize); + memset(crop_and_resize, 0, sizeof(CropAndResizeStruct)); + crop_and_resize->mapped_point_num_ = Num2; + crop_and_resize->base_.Prepare = CropAndResizePrepare; + crop_and_resize->base_.Resize = CropAndResizeResize; + crop_and_resize->base_.Compute = CropAndResizeCompute; + crop_and_resize->base_.Release = DefaultRelease; + return (KernelBase *)crop_and_resize; +} + +REG_KERNEL_CREATOR(PrimType_CropAndResize, kNumberTypeFloat32, CreateCropAndResize) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.h new file mode 100644 index 00000000..6d5a0d19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/crop_and_resize.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CROP_AND_RESIZE_H_ +#define NNACL_KERNEL_CROP_AND_RESIZE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int mapped_point_num_; + int batch_; + int new_height_; + int new_width_; + int *y_tops_; + int *y_bottoms_; + int *x_lefts_; + int *x_rights_; + float *y_bottom_weights_; + float *x_left_weights_; + float *line_buffer_; +} CropAndResizeStruct; + +KernelBase *CreateCropAndResize(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CROP_AND_RESIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.c new file mode 100644 index 00000000..ce7a66a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.c @@ -0,0 +1,337 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/deconvolution.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/kernel/deconvolution_winograd.h"
+#include "nnacl_c/kernel/deconvolution_depthwise.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/fp32/deconv_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_avx_fp32.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+
+int DeConvMallocWeightBiasData(ConvolutionBaseStruct *conv) {
+  int output_aligned_size = UP_ROUND(conv->compute_.out_c_, C8NUM) * sizeof(float);
+  size_t pack_weight_size = conv->compute_.in_c_ * conv->compute_.kernel_hw_ * output_aligned_size;
+  if (!conv->base_.train_session_) {
+    conv->packed_weight_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, pack_weight_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_);
+  }
+  if (conv->bias_data_ == NULL) {
+    conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, output_aligned_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_);
+  }
+  memset(conv->bias_data_, 0, output_aligned_size);
+  return NNACL_OK;
+}
+
+void DeConvPackWeight(ConvolutionBaseStruct *conv) {
+  TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_VOID(weight_tensor);
+  void *weight_data = weight_tensor->data_ == NULL ? conv->origin_weight_ : weight_tensor->data_;
+  NNACL_CHECK_NULL_RETURN_VOID(weight_data);
+
+#ifdef ENABLE_AVX
+  PackNHWCToCXHWNXFp32((float *)weight_data, (float *)conv->packed_weight_, conv->compute_.in_c_,
+                       conv->compute_.kernel_hw_, conv->compute_.out_c_);
+#else
+  PackNHWCToC8HWN8Fp32((float *)weight_data, (float *)conv->packed_weight_, conv->compute_.in_c_,
+                       conv->compute_.kernel_hw_, conv->compute_.out_c_);
+#endif
+}
+
+int DeConvInitParam(DeConvStruct *deconv) {
+  ConvComputeParam *compute = &deconv->conv_.compute_;
+  deconv->matmul_.row_ = compute->in_hw_;
+  deconv->matmul_.deep_ = compute->in_c_;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_c_, compute->kernel_hw_, NNACL_ERR);
+  deconv->matmul_.col_ = compute->out_c_ * compute->kernel_hw_;
+  deconv->matmul_.row_align_ = UP_ROUND(deconv->matmul_.row_, deconv->matmul_.row_tile_);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(UP_ROUND(compute->out_c_, C8NUM), compute->kernel_hw_, NNACL_ERR);
+  deconv->matmul_.col_align_ = UP_ROUND(compute->out_c_, C8NUM) * compute->kernel_hw_;
+
+  deconv->conv_.base_.thread_nr_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, UP_DIV(compute->out_c_, C8NUM));
+  NNACL_CHECK_ZERO_RETURN_ERR(deconv->conv_.base_.thread_nr_);
+#ifdef ENABLE_AVX
+  deconv->thread_stride_ = UP_DIV(UP_DIV(compute->out_c_, C8NUM * C3NUM), deconv->conv_.base_.thread_nr_) * C3NUM;
+#else
+  deconv->thread_stride_ = UP_DIV(UP_DIV(compute->out_c_, C8NUM), deconv->conv_.base_.thread_nr_);
+#endif
+  return NNACL_OK;
+}
+
+int DeConvRun(void *cdata, int task_id, float l, float r) {
+  DeConvStruct *deconv = (DeConvStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(deconv);
+
+  int total_thread_stride_ = task_id * deconv->thread_stride_;
+  int res_stride = UP_DIV(deconv->conv_.compute_.out_c_, C8NUM) - total_thread_stride_;
+  int oc = NNACL_MIN(deconv->thread_stride_, res_stride);
+  int cur_stride = deconv->thread_stride_ * C8NUM;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_, C8NUM, NNACL_ERR);
+  int total_thread_stride_c8 = total_thread_stride_ * C8NUM;
+  res_stride = deconv->conv_.compute_.out_c_ - total_thread_stride_c8;
+  int oc_res = NNACL_MIN(cur_stride, res_stride);
+  if (oc <= 0 || oc_res <= 0) {
+    return NNACL_OK;
+  }
+
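+  // Added note on the slicing above (inferred from the C8 packing): each task owns `oc` blocks of eight output
+  // channels starting at block index total_thread_stride_, and oc_res is that slice's raw channel count clamped
+  // at the tensor boundary; the offsets computed next scale this block index into the packed weight, the matmul
+  // scratch buffer, and the packed output.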
NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_c8, deconv->conv_.compute_.kernel_hw_, NNACL_ERR); + int plane_thread_stride_c8 = total_thread_stride_c8 * deconv->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_thread_stride_c8, deconv->matmul_.row_align_, NNACL_ERR); + int row_c8 = plane_thread_stride_c8 * deconv->matmul_.row_align_; + float *tmp_buffer = deconv->tmp_buffer_ + row_c8; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_thread_stride_c8, deconv->matmul_.deep_, NNACL_ERR); + int deep_c8 = plane_thread_stride_c8 * deconv->matmul_.deep_; + +#ifdef ENABLE_AVX + DeconvMatmulAvx(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, + deconv->matmul_.deep_, deconv->matmul_.row_align_, oc * C8NUM * deconv->conv_.compute_.kernel_hw_, + deconv->conv_.compute_.kernel_hw_); +#elif ENABLE_SSE + DeconvMatmulFloatSse(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, + deconv->matmul_.deep_, deconv->matmul_.row_align_, + oc * C8NUM * deconv->conv_.compute_.kernel_hw_); +#else + MatMulOpt(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, NULL, ActType_No, + deconv->matmul_.deep_, deconv->matmul_.row_align_, oc * C8NUM * deconv->conv_.compute_.kernel_hw_, + deconv->matmul_.col_, OutType_C8); +#endif + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_stride_c8, deconv->conv_.compute_.out_hw_, NNACL_ERR); + DeConvPostFp32C8(tmp_buffer, deconv->pack_output_ + total_thread_stride_c8 * deconv->conv_.compute_.out_hw_, + (float *)deconv->conv_.bias_data_ + total_thread_stride_c8, + deconv->output_ptr_ + total_thread_stride_c8, oc_res, (ConvParameter *)deconv->conv_.base_.param_); + return NNACL_OK; +} + +void DeConvFreeRunBuf(DeConvStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (deconv->pack_output_ != NULL) { + env->Free(env->allocator_, deconv->pack_output_); + deconv->pack_output_ = NULL; + } + if (deconv->tmp_buffer_ != NULL) { + env->Free(env->allocator_, deconv->tmp_buffer_); + deconv->tmp_buffer_ = NULL; + } + if (deconv->pack_input_ != NULL) { + env->Free(env->allocator_, deconv->pack_input_); + deconv->pack_input_ = NULL; + } +} + +int DeConvInitRunBuf(DeConvStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + int pack_output_size = UP_ROUND(deconv->conv_.compute_.out_c_, C8NUM) * deconv->conv_.compute_.out_hw_; + deconv->pack_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->pack_output_); + + int tmp_buffer_size = deconv->matmul_.row_align_ * deconv->matmul_.col_align_; + deconv->tmp_buffer_ = (float *)env->Alloc(env->allocator_, tmp_buffer_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tmp_buffer_); + + int pack_input_size = deconv->matmul_.row_align_ * deconv->matmul_.deep_; + deconv->pack_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->pack_input_); + + return NNACL_OK; +} + +int DeConvCheckvResizeValid(ConvolutionBaseStruct *conv) { + // ===============check in channel================= // + TensorC *input_tensor = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + + int resize_out_channel = NNACLGetChannel(input_tensor); + int filter_out_channel = NNACLGetBatch(filter_tensor); + if (filter_out_channel != resize_out_channel) { + 
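+    /* For Conv2dTranspose the weight tensor's batch dimension carries the input-channel count, so a mismatch here means the resized input is incompatible with the packed weight. */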
return NNACL_DECONV_RESIZE_OC_INVALID; + } + return NNACL_OK; +} + +int DeConvResize(KernelBase *self) { + DeConvStruct *deconv = (DeConvStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + (void)ConvBaseUpdateComputeInfo(&deconv->conv_); + + int ret = DeConvCheckvResizeValid(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = DeConvInitParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +int DeConvCompute(KernelBase *self) { + DeConvStruct *deconv = (DeConvStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + int error_code = ConvBaseRepackWeight(&deconv->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = DeConvInitRunBuf(deconv); + if (error_code != NNACL_OK) { + DeConvFreeRunBuf(deconv); + return error_code; + } + + float *src_in = (float *)self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + float *src_out = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_n_ - 1, deconv->conv_.compute_.in_c_, NNACL_ERR); + int input_bc = (deconv->conv_.compute_.in_n_ - 1) * deconv->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_hw_, input_bc, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.out_hw_, input_bc, NNACL_ERR); + for (int batch_index = 0; batch_index < deconv->conv_.compute_.in_n_; batch_index++) { + deconv->input_ptr_ = src_in + batch_index * deconv->conv_.compute_.in_hw_ * deconv->conv_.compute_.in_c_; + deconv->output_ptr_ = src_out + batch_index * deconv->conv_.compute_.out_hw_ * deconv->conv_.compute_.out_c_; + +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(deconv->input_ptr_, deconv->pack_input_, deconv->matmul_.row_, deconv->matmul_.deep_); +#else + RowMajor2Col12Major(deconv->input_ptr_, deconv->pack_input_, deconv->matmul_.row_, deconv->matmul_.deep_); +#endif + + error_code = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvRun, self, self->thread_nr_); + if (error_code != NNACL_OK) { + DeConvFreeRunBuf(deconv); + return error_code; + } + } + + DeConvFreeRunBuf(deconv); + return NNACL_OK; +} + +int DeConvPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + DeConvStruct *deconv = (DeConvStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvParameter *param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + // There could be weight dataType casting before Prepare, thus weight update is required. 
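+  /* The matmul row tile fixed below is architecture dependent: C4 on ARM32/AVX/SSE builds, C12 elsewhere. */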
+ ConvBaseUpdateOriginWeightAndBias(&deconv->conv_); + +#if defined(ENABLE_ARM32) || defined(ENABLE_AVX) || defined(ENABLE_SSE) + deconv->matmul_.row_tile_ = C4NUM; +#else + deconv->matmul_.row_tile_ = C12NUM; +#endif + + if (self->train_session_) { + int output_aligned_size = UP_ROUND(deconv->conv_.compute_.out_c_, C8NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_c_, deconv->conv_.compute_.kernel_hw_, NNACL_ERR); + int kernel_chw = deconv->conv_.compute_.in_c_ * deconv->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, output_aligned_size, NNACL_ERR); + size_t pack_weight_size = kernel_chw * output_aligned_size * sizeof(float); + self->work_size_ = pack_weight_size; + } + + if (self->in_[SECOND_INPUT]->data_ != NULL) { + int error_code = ConvBaseInitConvWeightBias(&deconv->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + } else { + deconv->conv_.is_repack_ = true; + } + + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateDeConv(ConvParameter *param) { + DeConvStruct *deconv = (DeConvStruct *)malloc(sizeof(DeConvStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv); + memset(deconv, 0, sizeof(DeConvStruct)); + deconv->conv_.malloc_weight_bias_ = DeConvMallocWeightBiasData; + deconv->conv_.pack_weight_ = DeConvPackWeight; + deconv->conv_.base_.Prepare = DeConvPrepare; + deconv->conv_.base_.Resize = DeConvResize; + deconv->conv_.base_.Release = DefaultRelease; + deconv->conv_.base_.Compute = DeConvCompute; + return &deconv->conv_; +} + +ConvolutionBaseStruct *SelectDeConv(ConvParameter *conv_param) { +#ifndef _WIN32 +#ifndef ENABLE_MCU + bool param_winograd_fit = (conv_param->stride_h_ > 1 || conv_param->stride_w_ > 1) && + (conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1); + +#ifdef ENABLE_AVX + bool in_size_winograd_fit = conv_param->input_w_ * conv_param->input_h_ >= NNACL_DECONV_WINOGRAD_HW_MAX; + bool size_winograd_fit = (conv_param->kernel_w_ / conv_param->stride_w_ >= C2NUM || + conv_param->kernel_h_ / conv_param->stride_h_ >= C2NUM || conv_param->output_channel_ == 1); +#else + bool in_size_winograd_fit = true; + bool size_winograd_fit = + (conv_param->kernel_w_ / conv_param->stride_w_ > C2NUM || conv_param->kernel_h_ / conv_param->stride_h_ > C2NUM); +#endif + + if (param_winograd_fit && size_winograd_fit && in_size_winograd_fit) { + ConvolutionBaseStruct *kernel = CreateDeConvWinograd(conv_param); + if (kernel != NULL) { + return kernel; + } + } +#endif +#endif + + return CreateDeConv(conv_param); +} + +KernelBase *CreateConvolutionTranspose(OpParameter *param, int data_type) { + ConvParameter *conv_param = (ConvParameter *)param; + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + + ConvolutionBaseStruct *conv = NULL; + if (conv_param->group_ == 1 && conv_param->input_channel_ == 1 && conv_param->output_channel_ == 1) { + conv = CreateDeConvDw(conv_param); + } else if (conv_param->group_ == 1) { + conv = SelectDeConv(conv_param); + } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) { + conv = CreateDeConvDw(conv_param); + } + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv); + ConvBaseUpdateParamInfo(&conv->compute_, conv_param); + return &conv->base_; +} + +REG_KERNEL_CREATOR(PrimType_Conv2dTransposeFusion, kNumberTypeFloat32, CreateConvolutionTranspose) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.h new file mode 100644 index 00000000..a7f773a6 --- 
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_DECONVOLUTION_H_ +#define NNACL_KERNEL_DECONVOLUTION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/convolution_base.h" +#include "nnacl_c/kernel/matmul_struct.h" + +typedef struct DeConvStruct { + ConvolutionBaseStruct conv_; + MatmulComputeParam matmul_; + int thread_stride_; + float *pack_input_; + float *pack_output_; + float *tmp_buffer_; + float *input_ptr_; + float *output_ptr_; +} DeConvStruct; + +int DeConvCheckvResizeValid(ConvolutionBaseStruct *conv); +KernelBase *CreateConvolutionTranspose(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_DECONVOLUTION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.c new file mode 100644 index 00000000..f612886b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.c @@ -0,0 +1,233 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/deconvolution_depthwise.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/conv_depthwise_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/kernel/deconvolution.h" + +int DeConvDwInitPackedInputOutput(DeConvDwStruct *deconv_dw) { + if (!deconv_dw->need_align_) { + return NNACL_OK; + } + ExecEnv *env = deconv_dw->conv_.base_.env_; + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + + int ic4 = UP_ROUND(compute->in_c_, compute->tile_num_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_n_, compute->in_hw_, NNACL_ERR); + int input_bhw = compute->in_n_ * compute->in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, ic4, NNACL_ERR); + int pack_input_size = input_bhw * ic4; + deconv_dw->packed_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->packed_input_); + + int oc4 = UP_ROUND(compute->out_c_, compute->tile_num_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR); + int output_bhw = compute->out_n_ * compute->out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc4, NNACL_ERR); + int pack_output_size = output_bhw * oc4; + deconv_dw->packed_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->packed_output_); + memset(deconv_dw->packed_output_, 0, pack_output_size * sizeof(float)); + + return NNACL_OK; +} + +int DeconvDwRun(void *cdata, int task_id, float l, float r) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + DeconvDwSWFp32(deconv_dw->packed_output_, deconv_dw->packed_input_, (float *)deconv_dw->conv_.packed_weight_, + (float *)deconv_dw->conv_.bias_data_, (ConvParameter *)deconv_dw->conv_.base_.param_, + &deconv_dw->sliding_, task_id); + return NNACL_OK; +} + +int DeConvDwMallocWeightBiasData(ConvolutionBaseStruct *conv) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + + int oc4 = UP_ROUND(conv->compute_.out_c_, conv->compute_.tile_num_); + if (!conv->base_.train_session_) { + int pack_weight_size = oc4 * conv->compute_.kernel_hw_; + NNACL_CHECK_MALLOC_SIZE(pack_weight_size); + deconv_dw->conv_.packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.packed_weight_); + } + + if (deconv_dw->conv_.bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(oc4 * sizeof(float)); + deconv_dw->conv_.bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, oc4 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.bias_data_); + } + memset(deconv_dw->conv_.bias_data_, 0, oc4 * sizeof(float)); + return NNACL_OK; +} + +void DeConvDwPackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = conv->base_.train_session_ ? 
conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackNCHWToNC4HW4Fp32(origin_weight, conv->packed_weight_, 1, conv->compute_.kernel_hw_, conv->compute_.out_c_); +} + +void DeConvDwFreePackedInputOutput(DeConvDwStruct *deconv_dw) { + if (deconv_dw->need_align_) { + ExecEnv *env = deconv_dw->conv_.base_.env_; + + env->Free(env->allocator_, deconv_dw->packed_input_); + deconv_dw->packed_input_ = NULL; + env->Free(env->allocator_, deconv_dw->packed_output_); + deconv_dw->packed_output_ = NULL; + } +} + +int DeConvDwPrepare(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + deconv_dw->conv_.compute_.tile_num_ = C4NUM; + + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + NNACL_CHECK_FALSE(compute->in_c_ != compute->out_c_, NNACL_DECONVOLUTION_DEPTHWISE_CHANNEL_INVALID); + NNACL_CHECK_FALSE(compute->dilation_h_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->dilation_w_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + + ConvBaseUpdateOriginWeightAndBias(&deconv_dw->conv_); + + if (self->train_session_) { + int oc4 = UP_ROUND(compute->out_c_, compute->tile_num_); + int pack_weight_size = oc4 * compute->kernel_hw_; + self->work_size_ = pack_weight_size; + } + + int ret = ConvBaseInitConvWeightBias(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.packed_weight_); + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.bias_data_); + return NNACL_OK; +} + +void DeConvDwUpdateParam(ConvolutionBaseStruct *conv) { + TensorC *input = conv->base_.in_[FIRST_INPUT]; + TensorC *output = conv->base_.out_[OUTPUT_INDEX]; + + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + conv_param->thread_num_ = conv->base_.thread_nr_; + conv_param->input_batch_ = NNACLGetBatch(output); + conv_param->input_h_ = NNACLGetHeight(output); + conv_param->input_w_ = NNACLGetWidth(output); + conv_param->input_channel_ = NNACLGetChannel(output); + conv_param->output_batch_ = NNACLGetBatch(input); + conv_param->output_h_ = NNACLGetHeight(input); + conv_param->output_w_ = NNACLGetWidth(input); + conv_param->output_channel_ = NNACLGetChannel(input); + + ConvComputeParam *compute = &conv->compute_; + compute->in_n_ = NNACLGetBatch(output); + compute->in_h_ = NNACLGetHeight(output); + compute->in_w_ = NNACLGetWidth(output); + compute->in_c_ = NNACLGetChannel(output); + compute->out_n_ = NNACLGetBatch(input); + compute->out_h_ = NNACLGetHeight(input); + compute->out_w_ = NNACLGetWidth(input); + compute->out_c_ = NNACLGetChannel(input); +} + +int DeConvDwResize(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + + (void)ConvBaseUpdateComputeInfo(&deconv_dw->conv_); + + int ret = DeConvCheckvResizeValid(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + int tile_num = deconv_dw->conv_.compute_.tile_num_; + DeConvDwUpdateParam(&deconv_dw->conv_); + (void)InitSlidingParamConvDw(&deconv_dw->sliding_, (ConvParameter *)self->param_, tile_num); + self->thread_nr_ = NNACL_MIN(self->thread_nr_, UP_DIV(deconv_dw->conv_.compute_.out_c_, tile_num)); + deconv_dw->need_align_ = deconv_dw->conv_.compute_.in_c_ % tile_num != 0; + + ret = ConvBasePrepare(&deconv_dw->conv_); + if (ret != NNACL_OK) { + 
return ret; + } + return NNACL_OK; +} + +int DeConvDwCompute(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + float *in_data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + float *out_data = (float *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + + int ret = ConvBaseRepackWeight(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = DeConvDwInitPackedInputOutput(deconv_dw); + if (ret != NNACL_OK) { + DeConvDwFreePackedInputOutput(deconv_dw); + return ret; + } + + if (deconv_dw->need_align_) { + PackNHWCToNHWC4Fp32(in_data, deconv_dw->packed_input_, compute->in_n_, compute->in_hw_, compute->in_c_); + } else { + deconv_dw->packed_input_ = in_data; + deconv_dw->packed_output_ = out_data; + memset(deconv_dw->packed_output_, 0, NNACLGetSize(out_tensor)); + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeconvDwRun, self, self->thread_nr_); + + if (deconv_dw->need_align_) { + PackNHWCXToNHWCFp32(deconv_dw->packed_output_, out_data, compute->out_n_, compute->out_hw_, compute->out_c_, + compute->tile_num_); + } + DeConvDwFreePackedInputOutput(deconv_dw); + return ret; +} + +ConvolutionBaseStruct *CreateDeConvDw(ConvParameter *param) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)malloc(sizeof(DeConvDwStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv_dw); + memset(deconv_dw, 0, sizeof(DeConvDwStruct)); + + deconv_dw->conv_.pack_weight_ = DeConvDwPackWeight; + deconv_dw->conv_.malloc_weight_bias_ = DeConvDwMallocWeightBiasData; + deconv_dw->conv_.base_.Prepare = DeConvDwPrepare; + deconv_dw->conv_.base_.Resize = DeConvDwResize; + deconv_dw->conv_.base_.Release = DefaultRelease; + deconv_dw->conv_.base_.Compute = DeConvDwCompute; + return &deconv_dw->conv_; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.h new file mode 100644 index 00000000..b929109e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_depthwise.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ +#define NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct DeConvDwStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sliding_; + bool need_align_; + float *packed_input_; + float *packed_output_; +} DeConvDwStruct; + +ConvolutionBaseStruct *CreateDeConvDw(ConvParameter *param); + +#endif // NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.c new file mode 100644 index 00000000..727cbc19 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.c @@ -0,0 +1,551 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _WIN32 +#ifndef ENABLE_MCU +#include "nnacl_c/kernel/deconvolution_winograd.h" +#include "nnacl_c/infer/common_infer.h" +#include "nnacl_c/fp32/deconv_winograd_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/kernel/deconvolution.h" + +void DeConvWinogradFreeResizeBuf(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + if (unit->tmp_buffer_ != NULL) { + free(unit->tmp_buffer_); + unit->tmp_buffer_ = NULL; + } + + if (unit->use_winograd_) { + if (unit->winograd_.b_buffer_ != NULL) { + free(unit->winograd_.b_buffer_); + unit->winograd_.b_buffer_ = NULL; + } + } + } + + for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) { + DeConvWgABuffer *wg = ¶m->a_buffer_[i]; + if (wg->buf_init_) { + if (wg->dest_buffer_ != NULL) { + free(wg->dest_buffer_); + wg->dest_buffer_ = NULL; + } + if (wg->middle_buffer_ != NULL) { + free(wg->middle_buffer_); + wg->middle_buffer_ = NULL; + } + } + wg->buf_init_ = false; + } + + if (deconv->tile_input_ != NULL) { + free(deconv->tile_input_); + deconv->tile_input_ = NULL; + } +} + +void DeConvWinogradFreeDeconvParam(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + + if (unit->weight_ != NULL) { + free(unit->weight_); + unit->weight_ = NULL; + } + + if (unit->use_winograd_) { + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + if (unit->winograd_.BT_ != NULL) { + free(unit->winograd_.BT_); + unit->winograd_.BT_ = NULL; + } + } + } + + if (param->compute_units_ != NULL) { + free(param->compute_units_); + param->compute_units_ = NULL; + } +} + +int DeConvWinogradInitParameter(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute = &deconv->conv_.compute_; + + int thread_num = deconv->conv_.base_.thread_nr_; + 
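+  /* thread_num is captured before thread_nr_ is clamped to the tile count below, so every per-thread buffer allocated in this function is sized for the worst case. */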
NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + + param->input_plane_ = compute->in_hw_; + param->output_plane_ = compute->out_hw_; + + param->in_tile_w_count_ = UP_DIV(compute->in_w_, WINOGRAD_DEFAULT_UNIT); + NNACL_CHECK_ZERO_RETURN_ERR(param->in_tile_w_count_); + param->in_tile_h_count_ = UP_DIV(compute->in_h_, WINOGRAD_DEFAULT_UNIT); + NNACL_CHECK_ZERO_RETURN_ERR(param->in_tile_h_count_); + param->in_tile_count_ = UP_DIV(param->in_tile_w_count_ * param->in_tile_h_count_, WINOGRAD_DEFAULT_TILE); + + deconv->conv_.base_.thread_nr_ = NNACL_MAX(1, deconv->conv_.base_.thread_nr_); + deconv->conv_.base_.thread_nr_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, param->in_tile_count_); + + deconv->thread_num_hw_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, compute->out_hw_); + NNACL_CHECK_ZERO_RETURN_ERR(deconv->thread_num_hw_); + deconv->thread_stride_hw_ = UP_DIV(compute->out_hw_, deconv->thread_num_hw_); + + int total_ic_up = WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_TILE * param->ic_up_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.base_.thread_nr_, total_ic_up, NNACL_ERR); + int size = deconv->conv_.base_.thread_nr_ * total_ic_up; + NNACL_CHECK_MALLOC_SIZE(size * sizeof(float)); + deconv->tile_input_ = (float *)malloc(size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tile_input_); + (void)memset(deconv->tile_input_, 0, size * sizeof(float)); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((WINOGRAD_DEFAULT_UNIT - 1), compute->stride_w_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((WINOGRAD_DEFAULT_UNIT - 1), compute->stride_h_, NNACL_ERR); + param->out_tile_w_ = (WINOGRAD_DEFAULT_UNIT - 1) * compute->stride_w_ + compute->kernel_w_; + param->out_tile_h_ = (WINOGRAD_DEFAULT_UNIT - 1) * compute->stride_h_ + compute->kernel_h_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + if (unit->use_winograd_) { + if (!param->a_buffer_[unit->winograd_.kh_].buf_init_) { + param->a_buffer_[unit->winograd_.kh_].buf_init_ = true; + size = unit->winograd_.kh_ * unit->winograd_.kw_ * WINOGRAD_DEFAULT_TILE * param->ic_up_; + + param->a_buffer_[unit->winograd_.kh_].middle_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->a_buffer_[unit->winograd_.kh_].middle_buffer_); + + param->a_buffer_[unit->winograd_.kh_].dest_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->a_buffer_[unit->winograd_.kh_].dest_buffer_); + } + + size = unit->winograd_.kh_ * unit->winograd_.kw_ * param->oc_up_ * WINOGRAD_DEFAULT_TILE; + unit->winograd_.b_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->winograd_.b_buffer_); + + size = unit->winograd_.kh_ * unit->winograd_.kw_ * param->oc_div_ * WINOGRAD_DEFAULT_TILE * compute->tile_num_; + unit->tmp_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->tmp_buffer_); + } else { + size = param->oc_div_ * unit->w_size_ * unit->h_size_ * WINOGRAD_DEFAULT_TILE * compute->tile_num_; + unit->tmp_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->tmp_buffer_); + } + } + + return NNACL_OK; +} + +int DeConvWgFp32Run(void *cdata, int task_id, float l, float r) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvParameter *conv_param = (ConvParameter *)deconv->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + DeConvParam 
*param = &deconv->param_; + ConvComputeParam *compute = &deconv->conv_.compute_; + + for (int tile_index = task_id; tile_index < param->in_tile_count_; tile_index += deconv->conv_.base_.thread_nr_) { + int size = WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_TILE * param->ic_up_; + float *tile_in = deconv->tile_input_ + task_id * size; + size = param->out_tile_w_ * param->out_tile_h_ * WINOGRAD_DEFAULT_TILE * param->oc_div_ * compute->tile_num_; + float *tile_out = deconv->tile_output_ + task_id * size; + (void)memset(tile_out, 0, size * sizeof(float)); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(tile_index, WINOGRAD_DEFAULT_TILE, NNACL_ERR); + int start_index = tile_index * WINOGRAD_DEFAULT_TILE; + int cal_count = NNACL_MIN(WINOGRAD_DEFAULT_TILE, param->in_tile_w_count_ * param->in_tile_h_count_ - start_index); + + int ret = DeconvWg(deconv->nhwc_input_, tile_in, tile_out, start_index, cal_count, conv_param, param, task_id); + if (ret != NNACL_OK) { + return ret; + } + + (void)pthread_mutex_lock(&deconv->lock_); + (void)DeconvWgPost(tile_out, deconv->nc4hw4_output_, conv_param, param, cal_count, tile_index); + (void)pthread_mutex_unlock(&deconv->lock_); + } + return NNACL_OK; +} + +int DeConvWgPostFp32Run(void *cdata, int task_id, float l, float r) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvComputeParam *compute = &deconv->conv_.compute_; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, deconv->thread_stride_hw_, NNACL_ERR); + int output_stride_plane = task_id * deconv->thread_stride_hw_; + int rest_plane = compute->out_hw_ - output_stride_plane; + int current_plane = MSMIN(rest_plane, deconv->thread_stride_hw_); + if (current_plane <= 0) { + return NNACL_OK; + } + + ActType act = ((ConvParameter *)deconv->conv_.base_.param_)->act_type_; + float *bias = (float *)deconv->conv_.bias_data_; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_stride_plane, deconv->conv_.compute_.tile_num_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_stride_plane, deconv->conv_.compute_.out_c_, NNACL_ERR); + WinogradPostConvFuncFp32CX(deconv->nc4hw4_output_ + output_stride_plane * compute->tile_num_, + deconv->nhwc_output_ + output_stride_plane * compute->out_c_, bias, compute->out_c_, + current_plane, compute->out_hw_, act); + return NNACL_OK; +} + +int DeConvWinogradInitComputeParam(DeConvWinogradStruct *deconv) { + deconv->valid_weight_shape_ = CheckShaleValid(&deconv->conv_.base_.in_[SECOND_INPUT], Num1); + if (deconv->valid_weight_shape_ == false) { + return NNACL_OK; + } + + ConvComputeParam *compute = &deconv->conv_.compute_; + DeConvParam *param = &deconv->param_; + + param->kernel_plane_ = compute->kernel_hw_; + param->ic_div_ = UP_DIV(compute->in_c_, compute->tile_num_); + param->oc_div_ = UP_DIV(compute->out_c_, compute->tile_num_); + param->ic_up_ = param->ic_div_ * compute->tile_num_; + param->oc_up_ = param->oc_div_ * compute->tile_num_; + + param->compute_size_ = 0; + for (int si_h = 0; si_h < compute->stride_h_; si_h++) { + for (int si_w = 0; si_w < compute->stride_w_; si_w++) { + if (si_h < compute->kernel_h_ && si_w < compute->kernel_w_) { + param->compute_size_++; + } + } + } + + size_t size = (size_t)param->compute_size_ * sizeof(DeConvComputeUnit); + param->compute_units_ = (DeConvComputeUnit *)(malloc(size)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->compute_units_); + + int cur_count = 0; + for (int si_h = 0; si_h < compute->stride_h_; si_h++) { + if (si_h >= compute->kernel_h_) { + continue; + } + for (int 
si_w = 0; si_w < compute->stride_w_; si_w++) { + if (si_w >= compute->kernel_w_) { + continue; + } + + int h_size = 1 + (compute->kernel_h_ - si_h - 1) / compute->stride_h_; + int w_size = 1 + (compute->kernel_w_ - si_w - 1) / compute->stride_w_; + + DeConvComputeUnit unit; + unit.winograd_.AT_ = NULL; + unit.winograd_.BT_ = NULL; + + unit.h_start_ = si_h; + unit.w_start_ = si_w; + unit.h_size_ = h_size; + unit.w_size_ = w_size; + + unit.use_winograd_ = false; + if (h_size == w_size) { + unit.winograd_.k_ = unit.h_size_; + unit.winograd_.i_ = WINOGRAD_DEFAULT_UNIT; + unit.winograd_.o_ = WINOGRAD_DEFAULT_UNIT + unit.h_size_ - 1; + unit.winograd_.kh_ = unit.h_size_ + WINOGRAD_DEFAULT_UNIT - 1; + unit.winograd_.kw_ = unit.w_size_ + WINOGRAD_DEFAULT_UNIT - 1; + unit.use_winograd_ = unit.winograd_.kh_ < WINOGRAD_MAX_COUNT && unit.winograd_.kw_ < WINOGRAD_MAX_COUNT; + } + if (unit.use_winograd_) { + unit.winograd_.b_buffer_ = NULL; + unit.weight_ = malloc(unit.winograd_.kh_ * unit.winograd_.kw_ * param->oc_up_ * param->ic_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit.weight_); + } else { + unit.weight_ = malloc(h_size * w_size * param->ic_up_ * param->oc_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit.weight_); + } + unit.tmp_buffer_ = NULL; + param->compute_units_[cur_count] = unit; + cur_count++; + } + } + return NNACL_OK; +} + +int DeConvWinogradInitDataParam(DeConvWinogradStruct *deconv) { + TensorC *weight_tensor = deconv->conv_.base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + float *nhwc_weight = weight_tensor->data_; + if (nhwc_weight == NULL) { + deconv->conv_.is_repack_ = true; + return NNACL_OK; + } + + DeConvParam *param = &deconv->param_; + + /* unit data : weight & winograd data */ + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + int ret = PackDeConvWgDataFp32(nhwc_weight, unit, (ConvParameter *)deconv->conv_.base_.param_, param); + if (ret != NNACL_OK) { + return ret; + } + } + + /* bias */ + ExecEnv *env = deconv->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + if (deconv->conv_.bias_data_ != NULL) { + env->Free(env->allocator_, deconv->conv_.bias_data_); + deconv->conv_.bias_data_ = NULL; + } + deconv->conv_.bias_data_ = env->Alloc(env->allocator_, param->oc_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->conv_.bias_data_); + (void)memset(deconv->conv_.bias_data_, 0, param->oc_up_ * sizeof(float)); + + if (deconv->conv_.base_.in_size_ == THREE_TENSOR) { + TensorC *bias_tensor = deconv->conv_.base_.in_[THIRD_INPUT]; + if (bias_tensor->shape_size_ == Num1 && NNACLGetElementNum(bias_tensor) == deconv->conv_.compute_.out_c_) { + (void)memcpy(deconv->conv_.bias_data_, bias_tensor->data_, deconv->conv_.compute_.out_c_ * sizeof(float)); + } + } + return NNACL_OK; +} + +int DeConvWinogradInitRunBuf(DeConvWinogradStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + + int size = deconv->param_.oc_up_ * deconv->conv_.compute_.out_hw_; + deconv->nc4hw4_output_ = (float *)env->Alloc(env->allocator_, size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->nc4hw4_output_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->param_.out_tile_w_, deconv->param_.out_tile_h_, NNACL_ERR); + int out_tile_hw = deconv->param_.out_tile_w_ * deconv->param_.out_tile_h_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.base_.thread_nr_, out_tile_hw, NNACL_ERR); + int total_out_tile_hw = deconv->conv_.base_.thread_nr_ * out_tile_hw; + 
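+  /* tile_output_ holds one out_tile_h x out_tile_w patch of oc_up_ channels for each of the WINOGRAD_DEFAULT_TILE in-flight positions per thread; every factor is overflow-checked before it is multiplied in. */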
NNACL_CHECK_INT_MUL_NOT_OVERFLOW(WINOGRAD_DEFAULT_TILE, deconv->param_.oc_up_, NNACL_ERR); + int tile_oc_up = WINOGRAD_DEFAULT_TILE * deconv->param_.oc_up_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_out_tile_hw, tile_oc_up, NNACL_ERR); + size = total_out_tile_hw * tile_oc_up; + deconv->tile_output_ = (float *)env->Alloc(env->allocator_, size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tile_output_); + + return NNACL_OK; +} + +void DeConvWinogradFreeRunBuf(DeConvWinogradStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + + if (deconv->nc4hw4_output_ != NULL) { + env->Free(env->allocator_, deconv->nc4hw4_output_); + deconv->nc4hw4_output_ = NULL; + } + + if (deconv->tile_output_ != NULL) { + env->Free(env->allocator_, deconv->tile_output_); + deconv->tile_output_ = NULL; + } +} + +int InitTrainComputeInit(DeConvWinogradStruct *deconv) { + if (!deconv->valid_weight_shape_) { + int ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + if (!deconv->valid_weight_shape_ || DeConvWinogradInitParameter(deconv) != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_SHAPE; + } + } + + if (deconv->conv_.is_repack_ && DeConvWinogradInitDataParam(deconv) != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_REPACK; + } + + return NNACL_OK; +} + +int DeConvWinogradPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvComputeParam *compute = &deconv->conv_.compute_; + NNACL_CHECK_FALSE(compute->dilation_h_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->dilation_w_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->stride_h_ == Num0, NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID); + NNACL_CHECK_FALSE(compute->stride_w_ == Num0, NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID); + +#ifdef ENABLE_AVX + compute->tile_num_ = C8NUM; +#else + compute->tile_num_ = C4NUM; +#endif + + ConvBaseUpdateOriginWeightAndBias(&deconv->conv_); + + int ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + + if (deconv->valid_weight_shape_) { + ret = DeConvWinogradInitDataParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + } + + // when input data is const tensor, save data in kernel + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + if (NNACLIsConst(input_tensor)) { + deconv->origin_input_ = (float *)malloc(NNACLGetSize(input_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->origin_input_); + (void)memcpy(deconv->origin_input_, input_tensor->data_, NNACLGetSize(input_tensor)); + } + return NNACL_OK; +} + +int DeConvWinogradResize(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + (void)ConvBaseUpdateComputeInfo(&deconv->conv_); + + int ret = DeConvCheckvResizeValid(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + DeConvWinogradFreeResizeBuf(deconv); + + ret = ConvBasePrepare(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + if (!deconv->valid_weight_shape_) { + ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + if 
(!deconv->valid_weight_shape_) { + return NNACL_OK; + } + ret = DeConvWinogradInitDataParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + } + + ret = DeConvWinogradInitParameter(deconv); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.out_hw_, deconv->conv_.compute_.out_c_, NNACL_ERR); + int output_chw = deconv->conv_.compute_.out_hw_ * deconv->conv_.compute_.out_c_; + if (output_chw <= kDeconvWinogradMaxPixel) { + self->thread_nr_ = NNACL_MIN(self->thread_nr_, Num3); + } + return NNACL_OK; +} + +int DeConvWinogradRelease(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + DeConvWinogradFreeResizeBuf(deconv); + DeConvWinogradFreeDeconvParam(deconv); + + if (deconv->origin_input_ != NULL) { + free(deconv->origin_input_); + deconv->origin_input_ = NULL; + } + return NNACL_OK; +} + +int DeConvWinogradCompute(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute_ = &deconv->conv_.compute_; + + int ret = DeConvWinogradInitRunBuf(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + ret = InitTrainComputeInit(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + float *src_in = deconv->origin_input_ != NULL ? deconv->origin_input_ : (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *src_out = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + int input_chw = compute_->in_hw_ * compute_->in_c_; + int output_chw = compute_->out_hw_ * compute_->out_c_; + for (int batch_index = 0; batch_index < compute_->in_n_; batch_index++) { + deconv->nhwc_input_ = src_in + batch_index * input_chw; + deconv->nhwc_output_ = src_out + batch_index * output_chw; + + (void)memset(deconv->nc4hw4_output_, 0, compute_->out_hw_ * param->oc_div_ * compute_->tile_num_ * sizeof(float)); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvWgFp32Run, self, self->thread_nr_); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + /* post bias activate and nhwc */ + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvWgPostFp32Run, self, self->thread_nr_); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + } + + DeConvWinogradFreeRunBuf(deconv); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateDeConvWinograd(ConvParameter *param) { + DeConvWinogradStruct *deconv_winograd = (DeConvWinogradStruct *)malloc(sizeof(DeConvWinogradStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv_winograd); + memset(deconv_winograd, 0, sizeof(DeConvWinogradStruct)); + + deconv_winograd->conv_.base_.Prepare = DeConvWinogradPrepare; + deconv_winograd->conv_.base_.Resize = DeConvWinogradResize; + deconv_winograd->conv_.base_.Release = DeConvWinogradRelease; + deconv_winograd->conv_.base_.Compute = DeConvWinogradCompute; + return &deconv_winograd->conv_; +} +#endif +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.h new file mode 100644 index 00000000..eabdffdf --- 
/dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/deconvolution_winograd.h @@ -0,0 +1,52 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ +#define NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ + +#ifndef _WIN32 +#ifndef ENABLE_MCU +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/convolution_base.h" + +#define kDeconvWinogradMaxPixel 3145728 +#define WINOGRAD_DEFAULT_UNIT 3 +#define WINOGRAD_DEFAULT_TILE 8 +#define WINOGRAD_MAX_COUNT 8 + +typedef struct DeConvWinogradStruct { + ConvolutionBaseStruct conv_; + DeConvParam param_; + pthread_mutex_t lock_; + int thread_num_hw_; + int thread_stride_hw_; + float *nhwc_input_; + float *nhwc_output_; + float *tile_input_; + float *tile_output_; + float *origin_input_; + float *nc4hw4_output_; + bool valid_weight_shape_; +} DeConvWinogradStruct; + +#define NNACL_DECONV_WINOGRAD_HW_MAX 2000 + +ConvolutionBaseStruct *CreateDeConvWinograd(ConvParameter *param); +#endif +#endif +#endif // NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.c new file mode 100644 index 00000000..814336c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.c @@ -0,0 +1,55 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/default_kernel_base.h" + +int DefaultPrepare3In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare3In2Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare1In2Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare1In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare2In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + return NNACL_OK; +} + +int DefaultRelease(KernelBase *self) { return NNACL_OK; } diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.h new file mode 100644 index 00000000..ba666539 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/default_kernel_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ +#define NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +int DefaultPrepare3In2Out(KernelBase *self); +int DefaultPrepare1In1Out(KernelBase *self); +int DefaultPrepare2In1Out(KernelBase *self); +int DefaultPrepare1In2Out(KernelBase *self); +int DefaultPrepare3In1Out(KernelBase *self); +int DefaultResize(KernelBase *self); +int DefaultRelease(KernelBase *self); + +#endif // NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.c new file mode 100644 index 00000000..3652b9fc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.c @@ -0,0 +1,80 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/depth_to_space.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/depth_to_space_parameter.h" +#include "nnacl_c/base/depth_to_space_base.h" + +int DepthToSpaceResize(KernelBase *self) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(depth_to_space); + DepthToSpaceArgs *args = &depth_to_space->args_; + + TensorC *input = self->in_[FIRST_INPUT]; + int32_t in_strides[DIMENSION_4D] = {0}; + ComputeStrides(input->shape_, in_strides, input->shape_size_); + args->in_stride_dim0_ = in_strides[Index0]; + args->in_stride_dim1_ = in_strides[Index1]; + args->in_stride_dim2_ = in_strides[Index2]; + + TensorC *output = self->out_[OUTPUT_INDEX]; + int32_t out_strides[DIMENSION_4D] = {0}; + ComputeStrides(output->shape_, out_strides, output->shape_size_); + args->out_stride_dim0_ = out_strides[Index0]; + args->out_stride_dim1_ = out_strides[Index1]; + args->out_stride_dim2_ = out_strides[Index2]; + return NNACL_OK; +} + +int DepthToSpaceCompute(KernelBase *self) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(depth_to_space); + int mode = ((DepthToSpaceParameter *)self->param_)->mode_; + + TensorC *input = self->in_[FIRST_INPUT]; + TensorC *output = self->out_[OUTPUT_INDEX]; + + if (mode == 0) { + // DCR (depth-column-row) + DepthToSpaceForNHWC(input->data_, output->data_, input->shape_, &depth_to_space->args_); + } else if (mode == 1) { + // CRD (column-row-depth) + DepthToSpaceCRDForNHWC(input->data_, output->data_, input->shape_, &depth_to_space->args_); + } else { + return NNACL_DEPTH_TO_SPACE_INVALID_MODE; + } + return NNACL_OK; +} + +KernelBase *CreateDepthToSpace(OpParameter *param, int data_type) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)malloc(sizeof(DepthToSpaceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(depth_to_space); + memset(depth_to_space, 0, sizeof(DepthToSpaceStruct)); + + depth_to_space->args_.data_type_size_ = DataTypeCSize(data_type); + depth_to_space->args_.block_size_ = ((DepthToSpaceParameter *)param)->block_size_; + depth_to_space->base_.Release = DefaultRelease; + depth_to_space->base_.Prepare = DefaultPrepare1In1Out; + depth_to_space->base_.Resize = DepthToSpaceResize; + depth_to_space->base_.Compute = DepthToSpaceCompute; + return (KernelBase *)depth_to_space; +} + +REG_KERNEL_CREATOR(PrimType_DepthToSpace, kNumberTypeFloat32, CreateDepthToSpace) +REG_KERNEL_CREATOR(PrimType_DepthToSpace, kNumberTypeFloat16, CreateDepthToSpace) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.h new file mode 100644 index 00000000..969ff818 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/depth_to_space.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_DEPTH_TO_SPACE_H_ +#define NNACL_KERNEL_DEPTH_TO_SPACE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct DepthToSpaceArgs { + int32_t in_stride_dim0_; + int32_t in_stride_dim1_; + int32_t in_stride_dim2_; + int32_t out_stride_dim0_; + int32_t out_stride_dim1_; + int32_t out_stride_dim2_; + uint8_t data_type_size_; + int32_t block_size_; +} DepthToSpaceArgs; + +typedef struct DepthToSpaceStruct { + KernelBase base_; + DepthToSpaceArgs args_; +} DepthToSpaceStruct; + +KernelBase *CreateDepthToSpace(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_DEPTH_TO_SPACE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.c new file mode 100644 index 00000000..e914d553 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.c @@ -0,0 +1,86 @@ +/** + * Copyright 2022-2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/exp.h" +#include +#include "nnacl_c/exp_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/exp_fp16.h" +#endif + +int ExpRunImpl(void *cdata, int task_id, float l, float r) { + ExpStruct *exp = (ExpStruct *)cdata; + return exp->Exp(exp->base_.in_[FIRST_INPUT]->data_, exp->base_.out_[OUTPUT_INDEX]->data_, exp, task_id); +} + +int ExpResize(struct KernelBase *self) { + ExpStruct *exp = (ExpStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(exp); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + exp->element_num_ = NNACLGetElementNum(exp->base_.in_[FIRST_INPUT]); + return NNACL_OK; +} + +int ExpPrepare(struct KernelBase *self) { + ExpStruct *exp = (ExpStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(exp); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < 1 || self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + float log_base = (param->base_ == -1) ? 
1 : logf(param->base_); + float epsilon = 0.000001; + exp->in_scale_ = param->scale_ * log_base; + if (param->shift_ == 0) { + exp->out_scale_ = 1; + } else { + if (fabs(log_base - 1) < epsilon) { + exp->out_scale_ = expf(param->shift_); + } else { + exp->out_scale_ = powf(param->base_, param->shift_); + } + } + + return NNACL_OK; +} + +int ExpCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ExpRunImpl, self, self->thread_nr_); +} + +KernelBase *CreateExp(OpParameter *param, int data_type) { + ExpStruct *exp = (ExpStruct *)malloc(sizeof(ExpStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(exp); + exp->base_.Prepare = ExpPrepare; + exp->base_.Resize = ExpResize; + exp->base_.Release = DefaultRelease; + exp->base_.Compute = ExpCompute; + exp->Exp = ExpFusionFp32; +#ifdef ENABLE_FP16 + if (data_type == kNumberTypeFloat16) { + exp->Exp = ExpFusionFp16; + } +#endif + return (KernelBase *)exp; +} + +REG_KERNEL_CREATOR(PrimType_ExpFusion, kNumberTypeFloat32, CreateExp) +REG_KERNEL_CREATOR(PrimType_ExpFusion, kNumberTypeFloat16, CreateExp) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.h new file mode 100644 index 00000000..35a5d2d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/exp.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_EXP_H_ +#define NNACL_KERNEL_EXP_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ExpStruct { + KernelBase base_; + float in_scale_; + float out_scale_; + int element_num_; + int (*Exp)(const void *in, void *out, const struct ExpStruct *exp, int task_id); +} ExpStruct; + +KernelBase *CreateExp(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_EXP_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.c new file mode 100644 index 00000000..1b6877a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/f16/arithmetic_compare_f16.h" +#include "nnacl_c/kernel/f16/arithmetic_f16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" + +typedef struct ArithmeticCompareF16Funcions { + int primitive_type_; + int activation_type_; + int (*compute_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); + int (*optimzie_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +} ArithmeticCompareF16Funcions; + +typedef struct ArithmeticCompareF16Struct { + ArithmeticF16Struct arithmetic_f16_; + ArithmeticCompareF16Funcions functions_; +} ArithmeticCompareF16Struct; + +void InitArithmeticCompareF16RunFunction(KernelBase *base) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base; + ArithmeticParameter *arithmetic_param = (ArithmeticParameter *)base->param_; + + ArithmeticCompareF16Funcions arithmetic_cp_fun_table_fp16[] = { + {PrimType_NotEqual, ActType_No, ElementNotEqualFp16, ElementOptNotEqualFp16}, + {PrimType_Equal, ActType_No, ElementEqualFp16, ElementOptEqualFp16}, + {PrimType_Less, ActType_No, ElementLessFp16, ElementOptLessFp16}, + {PrimType_LessEqual, ActType_No, ElementLessEqualFp16, ElementOptLessEqualFp16}, + {PrimType_Greater, ActType_No, ElementGreaterFp16, ElementOptGreaterFp16}, + {PrimType_GreaterEqual, ActType_No, ElementGreaterEqualFp16, ElementOptGreaterEqualFp16}}; + + size_t length = sizeof(arithmetic_cp_fun_table_fp16) / sizeof(ArithmeticCompareF16Funcions); + for (size_t i = 0; i < length; i++) { + if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == + arithmetic_compare_f16->arithmetic_f16_.arithmetic_.primitive_type_ && + arithmetic_cp_fun_table_fp16[i].activation_type_ == arithmetic_param->activation_type_) { + arithmetic_compare_f16->functions_ = arithmetic_cp_fun_table_fp16[i]; + return; + } + } +} + +int ArithmeticCompareF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, + int64_t size) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base; + + if (arithmetic_compare_f16->arithmetic_f16_.arithmetic_.scalar_opt_) { + bool first_scalar = arithmetic_compare_f16->arithmetic_f16_.arithmetic_.in_elements_num0_ == 1; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare_f16->functions_.optimzie_); + return arithmetic_compare_f16->functions_.optimzie_((const float16_t *)input0, (const float16_t *)input1, + (uint8_t *)output, size, first_scalar); + } + + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare_f16->functions_.compute_); + return arithmetic_compare_f16->functions_.compute_((const float16_t *)input0, (const float16_t *)input1, + (uint8_t *)output, size); +} +int ArithmeticCompareF16Compute(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + arithmetic->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_); + return ArithmeticF16Compute(self); +} + +KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = + (ArithmeticCompareF16Struct *)malloc(sizeof(ArithmeticCompareF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(arithmetic_compare_f16); + memset(arithmetic_compare_f16, 0, sizeof(ArithmeticCompareF16Struct)); + + ArithmeticStruct *arithmetic = &arithmetic_compare_f16->arithmetic_f16_.arithmetic_; +
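+ /* The compare kernel rides on the shared f16 arithmetic pipeline: the field + * defaults below mirror CreateArithmeticF16, while execute_ is routed to the + * uint8-writing comparator that init_function_ picks out of + * arithmetic_cp_fun_table_fp16 above. */ +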
arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticF16Resize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompareF16Compute; + + arithmetic->execute_ = ArithmeticCompareF16DoExecute; + arithmetic->tile_function_ = TileOneDimensionFp16; + arithmetic->init_function_ = InitArithmeticCompareF16RunFunction; + + return (KernelBase *)arithmetic_compare_f16; +} + +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.h new file mode 100644 index 00000000..0727c77d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_compare_f16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ +#define NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.c new file mode 100644 index 00000000..10ca6601 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.c @@ -0,0 +1,195 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/f16/arithmetic_f16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/arithmetic_fp16.h" +#include "nnacl_c/fp16/utils_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +void InitArithmeticF16RunFunction(KernelBase *base) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base; + + ArithmeticF16Funcions f16_fun_table[] = { + {PrimType_MulFusion, ActType_Relu, ElementMulReluFp16, ElementOptMulReluFp16}, + {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16}, + {PrimType_MulFusion, ActType_No, ElementMulFp16, ElementOptMulFp16}, + {PrimType_AddFusion, ActType_Relu, ElementAddReluFp16, ElementOptAddReluFp16}, + {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16}, + {PrimType_AddFusion, ActType_No, ElementAddFp16, ElementOptAddFp16}, + {PrimType_SubFusion, ActType_Relu, ElementSubReluFp16, ElementOptSubReluFp16}, + {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16}, + {PrimType_SubFusion, ActType_No, ElementSubFp16, ElementOptSubFp16}, + {PrimType_DivFusion, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimType_DivFusion, ActType_No, ElementDivFp16, ElementOptDivFp16}, + {PrimType_RealDiv, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimType_RealDiv, ActType_No, ElementDivFp16, ElementOptDivFp16}, + {PrimType_FloorMod, ActType_No, ElementFloorModFp16, ElementOptFloorModFp16}, + {PrimType_FloorDiv, ActType_No, ElementFloorDivFp16, ElementOptFloorDivFp16}, + {PrimType_LogicalAnd, ActType_No, ElementLogicalAndFp16, ElementOptLogicalAndFp16}, + {PrimType_LogicalOr, ActType_No, ElementLogicalOrFp16, ElementOptLogicalOrFp16}, + {PrimType_SquaredDifference, ActType_No, ElementSquaredDifferenceFp16, ElementOptSquaredDifferenceFp16}, + {PrimType_Maximum, ActType_No, ElementMaximumFp16, ElementOptMaximumFp16}, + {PrimType_Minimum, ActType_No, ElementMinimumFp16, ElementOptMinimumFp16}}; + + size_t length = sizeof(f16_fun_table) / sizeof(ArithmeticF16Funcions); + for (size_t i = 0; i < length; i++) { + if (f16_fun_table[i].primitive_type_ == arithmetic_f16->arithmetic_.primitive_type_ && + f16_fun_table[i].activation_type_ == + ((ArithmeticParameter *)(arithmetic_f16->arithmetic_.base_.param_))->activation_type_) { + arithmetic_f16->functions_ = f16_fun_table[i]; + return; + } + } +} + +int ArithmeticF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base; + + if (arithmetic_f16->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.optimzie_); + return arithmetic_f16->functions_.optimzie_((const float16_t *)input0, (const float16_t *)input1, + (float16_t *)output, size, + arithmetic_f16->arithmetic_.in_elements_num0_ == 1); + } + + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.compute_); + return arithmetic_f16->functions_.compute_((const float16_t *)input0, (const float16_t *)input1, (float16_t *)output, + size); +} + +int ArithmeticF16Resize(KernelBase *self) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16); + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + + arithmetic->in_data_size_ = sizeof(float16_t); + arithmetic->out_data_size_ = 
sizeof(float16_t); + if (arithmetic->in_elements_num1_ != 1 && arithmetic->in_elements_num0_ != 1) { + if (arithmetic->a_matrix_.is_const_ && self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat32) { + TensorC *t = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(t->data_); + void *f32_data = t->data_; + t->data_type_ = kNumberTypeFloat16; + t->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(t->data_); + Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), NNACLGetElementNum(t)); + self->env_->Free(self->env_->allocator_, f32_data); + } + if (arithmetic->b_matrix_.is_const_ && self->in_[SECOND_INPUT]->data_type_ == kNumberTypeFloat32) { + TensorC *t = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(t->data_); + void *f32_data = t->data_; + t->data_type_ = kNumberTypeFloat16; + t->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(t->data_); + Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), NNACLGetElementNum(t)); + self->env_->Free(self->env_->allocator_, f32_data); + } + } + return ArithmeticResize(self); +} + +void FreeArithmeticF16Buffers(ArithmeticF16Struct *arithmetic_f16) { + for (int i = 0; i < THREE_TENSOR; i++) { + if (arithmetic_f16->tmp_buffer_[i] != NULL) { + arithmetic_f16->arithmetic_.base_.env_->Free(arithmetic_f16->arithmetic_.base_.env_->allocator_, + arithmetic_f16->tmp_buffer_[i]); + arithmetic_f16->tmp_buffer_[i] = NULL; + } + } +} + +int ArithmeticF16Compute(KernelBase *self) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16); + + int in0_data_type = self->in_[FIRST_INPUT]->data_type_; + int in1_data_type = self->in_[SECOND_INPUT]->data_type_; + int out_data_type = self->out_[OUTPUT_INDEX]->data_type_; + + NNACL_CHECK_FALSE(in0_data_type != kNumberTypeFloat32 && in0_data_type != kNumberTypeFloat16, + NNACL_UNSUPPORTED_DATA_TYPE); + NNACL_CHECK_FALSE(in1_data_type != kNumberTypeFloat16 && in1_data_type != kNumberTypeFloat32, + NNACL_UNSUPPORTED_DATA_TYPE); + + if (!arithmetic_f16->arithmetic_.a_matrix_.is_valid_) { + arithmetic_f16->arithmetic_.a_matrix_.data_ = GetOrAllocFp16Data(self->in_[FIRST_INPUT], self->env_, true); + arithmetic_f16->tmp_buffer_[FIRST_INPUT] = + in0_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.a_matrix_.data_; + } + + if (!arithmetic_f16->arithmetic_.b_matrix_.is_valid_) { + arithmetic_f16->arithmetic_.b_matrix_.data_ = GetOrAllocFp16Data(self->in_[SECOND_INPUT], self->env_, true); + arithmetic_f16->tmp_buffer_[SECOND_INPUT] = + in1_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.b_matrix_.data_; + } + + arithmetic_f16->arithmetic_.c_matrix_.data_ = GetOrAllocFp16Data(self->out_[OUTPUT_INDEX], self->env_, false); + arithmetic_f16->tmp_buffer_[THIRD_INPUT] = + out_data_type == kNumberTypeFloat16 ?
NULL : arithmetic_f16->arithmetic_.c_matrix_.data_; + + int ret = ArithmeticCompute(self); + if (ret == NNACL_OK && out_data_type == kNumberTypeFloat32) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->arithmetic_.c_matrix_.data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + Float16ToFloat32((float16_t *)(arithmetic_f16->arithmetic_.c_matrix_.data_), + (float *)(self->out_[OUTPUT_INDEX]->data_), NNACLGetElementNum(self->out_[OUTPUT_INDEX])); + } + + FreeArithmeticF16Buffers(arithmetic_f16); + return ret; +} + +KernelBase *CreateArithmeticF16(OpParameter *param, int data_type) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)malloc(sizeof(ArithmeticF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(arithmetic_f16); + memset(arithmetic_f16, 0, sizeof(ArithmeticF16Struct)); + + ArithmeticStruct *arithmetic = &arithmetic_f16->arithmetic_; + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticF16Resize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticF16Compute; + + arithmetic->execute_ = ArithmeticF16DoExecute; + arithmetic->tile_function_ = TileOneDimensionFp16; + arithmetic->init_function_ = InitArithmeticF16RunFunction; + + return (KernelBase *)arithmetic_f16; +} + +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat16, CreateArithmeticF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.h new file mode 100644 index 00000000..9e6f8fcc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/arithmetic_f16.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_ARITHMETIC_F16_H_ +#define NNACL_KERNEL_F16_ARITHMETIC_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/arithmetic.h" + +typedef struct ArithmeticF16Funcions { + int primitive_type_; + int activation_type_; + int (*compute_)(const float16_t *in1, const float16_t *in2, float16_t *out, int ele); + int (*optimzie_)(const float16_t *in1, const float16_t *in2, float16_t *out, int ele, bool first_scalar); +} ArithmeticF16Funcions; + +typedef struct ArithmeticF16Struct { + ArithmeticStruct arithmetic_; + ArithmeticF16Funcions functions_; + void *tmp_buffer_[THREE_TENSOR]; /* in_size + out_size */ +} ArithmeticF16Struct; + +KernelBase *CreateArithmeticF16(OpParameter *param, int data_type); +int ArithmeticF16Resize(KernelBase *self); +int ArithmeticF16Compute(KernelBase *self); + +#endif // MINDSPORE_NNACL_KERNEL_F16_ARITHMETIC_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.c new file mode 100644 index 00000000..ab6f67c1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.c @@ -0,0 +1,132 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/f16/concat_f16.h" +#include "nnacl_c/kernel/concat.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/fp16/utils_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +typedef struct ConcatF16Struct { + ConcatStruct concat_; + void **tmp_buffer_; /* in_size + out_size */ +} ConcatF16Struct; + +int ConcatEnsureFp16InputsAndOutput(ConcatF16Struct *concat_f16) { + ConcatStruct *concat = &concat_f16->concat_; + + int tmp_buffer_size = (concat->base_.in_size_ + concat->base_.out_size_) * sizeof(float16_t *); + concat_f16->tmp_buffer_ = concat->base_.env_->Alloc(concat->base_.env_->allocator_, tmp_buffer_size); + NNACL_CHECK_NULL_RETURN_ERR(concat_f16->tmp_buffer_); + memset(concat_f16->tmp_buffer_, 0, tmp_buffer_size); + + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + if (!concat->is_with_data_[i]) { + continue; + } + + concat->inputs_ptr_[i] = GetOrAllocFp16Data(concat->base_.in_[i], concat->base_.env_, true); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_[i]); + if (concat->base_.in_[i]->data_type_ == kNumberTypeFloat32 || + concat->base_.in_[i]->data_type_ == kNumberTypeFloat) { + concat_f16->tmp_buffer_[i] = concat->inputs_ptr_[i]; + } + } + + concat->output_ = GetOrAllocFp16Data(concat->base_.out_[OUTPUT_INDEX], concat->base_.env_, false); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->output_); + if (concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32 || + concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat) { + concat_f16->tmp_buffer_[concat->base_.in_size_] = concat->output_; + } + return NNACL_OK; +} + +int ConcatFp16Run(void *cdata, int task_id, float l, float r) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(concat_f16); + ConcatStruct *concat = &concat_f16->concat_; + return DoConcat(concat, task_id); +} + +void ConcatF16FreeTmpBuffer(ConcatF16Struct *concat_f16) { + if (concat_f16->tmp_buffer_ != NULL) { + /* free tmp_buffer_[i] */ + for (int i = 0; i < (concat_f16->concat_.base_.in_size_ + concat_f16->concat_.base_.out_size_); i++) { + if (concat_f16->tmp_buffer_[i] != NULL) { + concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_[i]); + } + concat_f16->tmp_buffer_[i] = NULL; + } + + /* free tmp_buffer_ */ + concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_); + concat_f16->tmp_buffer_ = NULL; + } +} + +int ConcatF16Compute(KernelBase *self) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat_f16); + ConcatStruct *concat = &concat_f16->concat_; + + if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) { + return NNACL_OK; + } + + int ret = ConcatEnsureFp16InputsAndOutput(concat_f16); + if (ret != NNACL_OK) { + ConcatF16FreeTmpBuffer(concat_f16); + return ret; + } + + NNACL_CHECK_NULL_RETURN_ERR(concat->output_); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatFp16Run, self, self->thread_nr_); + if (ret == NNACL_OK) { + TensorC *output_tensor = concat->base_.out_[FIRST_INPUT]; + if (output_tensor->data_type_ == kNumberTypeFloat32 || output_tensor->data_type_ == kNumberTypeFloat) { + float *output = concat->base_.out_[FIRST_INPUT]->data_; + if (output == NULL) { + ret = NNACL_CONCAT_F16_OUTPUT_DATA_INVALID; + } else { + Float16ToFloat32((float16_t *)concat->output_, output, NNACLGetElementNum(output_tensor)); + } + } + } + + ConcatF16FreeTmpBuffer(concat_f16); + return ret; 
+} + +KernelBase *CreateConcatF16(OpParameter *param, int data_type) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)malloc(sizeof(ConcatF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(concat_f16); + memset(concat_f16, 0, sizeof(ConcatF16Struct)); + + ConcatStruct *concat = &concat_f16->concat_; + concat->data_type_ = kNumberTypeFloat16; + concat->inner_sizes_ = NULL; + concat->inputs_ptr_ = NULL; + concat->is_with_data_ = NULL; + concat->base_.Prepare = ConcatPepare; + concat->base_.Resize = ConcatResize; + concat->base_.Release = ConcatRelease; + concat->base_.Compute = ConcatF16Compute; + concat_f16->tmp_buffer_ = NULL; + return (KernelBase *)concat; +} + +REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat16, CreateConcatF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.h new file mode 100644 index 00000000..7a6eb5af --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/concat_f16.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ +#define MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateConcatF16(OpParameter *param, int data_type); + +#endif // MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.c new file mode 100644 index 00000000..0f1cb06e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.c @@ -0,0 +1,118 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/f16/reduce_f16.h" +#include "nnacl_c/fp16/reduce_fp16.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +typedef struct ReduceF16Compute { + int type_; + int (*f16_reducer_)(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data, + float16_t *dst_data, const int tid, const int thread_num); +} ReduceF16Compute; + +typedef struct ReduceF16Struct { + ReduceStruct reduce_; + ReduceF16Compute compute_; +} ReduceF16Struct; + +int CallReduceF16Unit(KernelBase *base, int task_id) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)base; + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->reduce_.src_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->reduce_.dst_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->compute_.f16_reducer_); + + return reduce_f16->compute_.f16_reducer_(reduce_f16->reduce_.outer_size_, reduce_f16->reduce_.inner_size_, + reduce_f16->reduce_.axis_size_, + (const float16_t *)reduce_f16->reduce_.src_data_, + (float16_t *)reduce_f16->reduce_.dst_data_, task_id, base->thread_nr_); +} + +void InitialReduceF16KernelList(KernelBase *base) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)base; + ReduceParameter *param = (ReduceParameter *)(base->param_); + + ReduceF16Compute func_list[] = {{Reduce_Sum, ReduceSumFp16}, {Reduce_Mean, ReduceMeanFp16}, + {Reduce_Max, ReduceMaxFp16}, {Reduce_Min, ReduceMinFp16}, + {Reduce_Prod, ReduceProdFp16}, {Reduce_SumSquare, ReduceSumFp16}, + {Reduce_ASum, ReduceSumFp16}, {Reduce_L2, ReduceL2NormFp16}}; + + size_t list_len = sizeof(func_list) / sizeof(ReduceF16Compute); + for (size_t i = 0; i < list_len; ++i) { + if (param->mode_ == func_list[i].type_) { + reduce_f16->compute_ = func_list[i]; + return; + } + } +} + +void HandleReduceF16ASumAndSumSquare(KernelBase *base) { + TensorC *in_tensor = base->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(in_tensor); + float16_t *data = (float16_t *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(data); + + int num = NNACLGetElementNum(in_tensor); + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_ASum) { + for (int i = 0; i < num; ++i) { + if (data[i] < 0.0f) { + data[i] = 0.0f - data[i]; + } + } + } + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_SumSquare) { + for (int i = 0; i < num; ++i) { + data[i] = data[i] * data[i]; + } + return; + } +} + +int CalculateReduceF16CoeffOutput(KernelBase *base) { + TensorC *out_tensor = base->out_[OUTPUT_INDEX]; + int num = NNACLGetElementNum(out_tensor); + + float16_t *out_data = (float16_t *)out_tensor->data_; + for (int i = 0; i < num; ++i) { + out_data[i] *= ((ReduceParameter *)base->param_)->coeff; + } + return NNACL_OK; +} + +KernelBase *CreateReduceF16(OpParameter *param, int data_type) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)malloc(sizeof(ReduceF16Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reduce_f16); + memset(reduce_f16, 0, sizeof(ReduceF16Struct)); + + ReduceStruct *reduce = &reduce_f16->reduce_; + reduce->data_type_ = data_type; + reduce->base_.Release = DefaultRelease; + reduce->base_.Prepare = ReducePrepare; + reduce->base_.Resize = ReduceResize; + reduce->base_.Compute = ReduceCompute; + + reduce->handle_sum_square_ = HandleReduceF16ASumAndSumSquare; + reduce->calculate_coeff_ = CalculateReduceF16CoeffOutput; + reduce->init_kernel_list_ = InitialReduceF16KernelList; + reduce->call_uint_ = CallReduceF16Unit; + + return (KernelBase *)reduce_f16; +} + +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeFloat16,
CreateReduceF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.h new file mode 100644 index 00000000..df990afd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/reduce_f16.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_REDUCE_F16_H_ +#define NNACL_KERNEL_F16_REDUCE_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/reduce.h" + +KernelBase *CreateReduceF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_REDUCE_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.c new file mode 100644 index 00000000..cc748a0e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.c @@ -0,0 +1,96 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/f16/stack_f16.h" +#include "nnacl_c/fp16/cast_fp16.h" +#include "nnacl_c/tensor_c_utils.h" + +void *StackF16InitBuffer(KernelBase *base, TensorC *t, bool init) { + if (init == false) { + return t->data_; + } + + int ele_num = NNACLGetElementNum(t); + void *f16_buffer = base->env_->Alloc(base->env_->allocator_, ele_num * sizeof(float16_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(f16_buffer); + Float32ToFloat16(t->data_, f16_buffer, ele_num); + return f16_buffer; +} + +int StackF16InitMallocFlags(StackF16Struct *stack_f16) { + KernelBase *base = (KernelBase *)stack_f16; + stack_f16->init_ = base->env_->Alloc(base->env_->allocator_, (base->in_size_ + base->out_size_) * sizeof(bool)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->init_); + + for (size_t i = 0; i < base->in_size_; ++i) { + stack_f16->init_[i] = base->in_[i]->data_type_ == kNumberTypeFloat32; + stack_f16->stack_.buffers_[i] = StackF16InitBuffer(base, base->in_[i], stack_f16->init_[i]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[i]); + } + stack_f16->init_[base->in_size_] = base->out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32; + stack_f16->stack_.buffers_[base->in_size_] = + StackF16InitBuffer(base, base->out_[OUTPUT_INDEX], stack_f16->init_[base->in_size_]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[base->in_size_]); + return NNACL_OK; +} + +void StackF16FreeBuffer(StackF16Struct *stack_f16) { + if (stack_f16->init_[stack_f16->stack_.base_.in_size_]) { + /* output transfer */ + Float16ToFloat32((float16_t *)stack_f16->stack_.buffers_[stack_f16->stack_.base_.in_size_], + (float *)stack_f16->stack_.base_.out_[OUTPUT_INDEX]->data_, + NNACLGetElementNum(stack_f16->stack_.base_.out_[OUTPUT_INDEX])); + } + + for (size_t i = 0; i < (stack_f16->stack_.base_.in_size_ + stack_f16->stack_.base_.out_size_); ++i) { + if (stack_f16->init_[i]) { + stack_f16->stack_.base_.env_->Free(stack_f16->stack_.base_.env_->allocator_, stack_f16->stack_.buffers_[i]); + } + stack_f16->stack_.buffers_[i] = NULL; + } + + stack_f16->stack_.base_.env_->Free(stack_f16->stack_.base_.env_->allocator_, stack_f16->init_); + stack_f16->init_ = NULL; +} + +int StackF16Compute(KernelBase *self) { + StackF16Struct *stack_f16 = (StackF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack_f16); + + int ret = StackF16InitMallocFlags(stack_f16); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, StackRun, self, self->thread_nr_); + StackF16FreeBuffer(stack_f16); + return ret; +} + +KernelBase *CreateStackF16(OpParameter *param, int data_type) { + StackF16Struct *stack_f16 = (StackF16Struct *)malloc(sizeof(StackF16Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(stack_f16); + StackStruct *stack = &stack_f16->stack_; + stack->buffers_ = NULL; + stack->data_type_ = data_type; + stack->base_.Release = StackRelease; + stack->base_.Prepare = StackPrepare; + stack->base_.Resize = StackResize; + stack->base_.Compute = StackF16Compute; + return (KernelBase *)stack; +} + +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeFloat16, CreateStackF16) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.h new file mode 100644 index 00000000..640f04fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/f16/stack_f16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_STACK_F16_H_ +#define NNACL_KERNEL_F16_STACK_F16_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/stack.h" + +typedef struct StackF16Struct { + StackStruct stack_; + bool *init_; +} StackF16Struct; + +KernelBase *CreateStackF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_STACK_F16_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.c new file mode 100644 index 00000000..098704ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.c @@ -0,0 +1,102 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/fill.h" +#include "nnacl_c/fill_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/base/fill_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/fill_fp16.h" +#endif + +int FillResize(struct KernelBase *self) { + FillStruct *fill = (FillStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fill); + fill->base_.thread_nr_ = fill->base_.UpdateThread( + TC_PTYPE(PrimType_Fill), 0, 1, NNACLGetSize(fill->base_.out_[OUTPUT_INDEX]), fill->base_.thread_nr_); + + NNACL_CHECK_NULL_RETURN_ERR(fill->base_.out_[OUTPUT_INDEX]); + fill->data_size_ = (int)NNACLGetElementNum(fill->base_.out_[OUTPUT_INDEX]); + fill->thread_sz_count_ = MSMIN(fill->base_.thread_nr_, fill->data_size_); + if (fill->thread_sz_count_ != 0) { + fill->thread_sz_stride_ = UP_DIV(fill->data_size_, fill->thread_sz_count_); + } + return NNACL_OK; +} + +int FillImpl(void *cdata, int task_id, float l, float r) { + FillStruct *fill = (FillStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(fill); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, fill->thread_sz_stride_, NNACL_ERR); + int size = MSMIN(fill->thread_sz_stride_, fill->data_size_ - task_id * fill->thread_sz_stride_); + NNACL_CHECK_FALSE(size <= 0, NNACL_OK); + int offset = task_id * fill->thread_sz_stride_; + int ret = NNACL_OK; + switch (fill->base_.in_[FIRST_INPUT]->data_type_) { +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + ret = FillFp16((float16_t *)fill->out_ptr_ + offset, size, ((float16_t *)fill->src_data_)[FIRST_INPUT]); + break; +#endif + case kNumberTypeFloat32: + ret = FillFp32((float *)fill->out_ptr_ + offset, size, ((float *)fill->src_data_)[FIRST_INPUT]); + break; + case kNumberTypeInt32: + ret = FillInt32((int *)fill->out_ptr_ + offset, size, ((int *)fill->src_data_)[FIRST_INPUT]); + break; + case kNumberTypeBool: + ret = FillBool((bool *)fill->out_ptr_ + offset, size, ((bool *)fill->src_data_)[FIRST_INPUT]); + break; + default: + return NNACL_FILL_DATA_TYPE_INVALID; + } + return ret; +} + +int FillCompute(struct KernelBase *self) { + FillStruct *fill = (FillStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fill); + + fill->src_data_ = (void *)fill->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(fill->src_data_); + fill->out_ptr_ = (void *)fill->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(fill->out_ptr_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, FillImpl, fill, fill->base_.thread_nr_); +} + +KernelBase *CreateFill(OpParameter *param, int data_type) { + FillStruct *fill = (FillStruct *)malloc(sizeof(FillStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fill); + fill->base_.Prepare = DefaultPrepare2In1Out; + fill->base_.Resize = FillResize; + fill->base_.Release = DefaultRelease; + fill->base_.Compute = FillCompute; + return (KernelBase *)fill; +} + +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeBool, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeInt32, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat32, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat16, CreateFill); + +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeBool, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeInt32, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat32, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat16, CreateFill); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.h 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.h new file mode 100644 index 00000000..3cb44136 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fill.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_FILL_H_ +#define NNACL_KERNEL_FILL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct FillStruct { + KernelBase base_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + void *src_data_; + void *out_ptr_; + int thread_count_; +} FillStruct; + +KernelBase *CreateFill(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FILL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.c new file mode 100644 index 00000000..215aa8e2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.c @@ -0,0 +1,81 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/fullconnection.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/kernel/matmul_create.h" + +int FullConnectionPrepare(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR); + + if (matmul->a_const_ || matmul->infer_shape_) { + int *a_shape = self->in_[FIRST_INPUT]->shape_; + matmul->compute_.row_ = a_shape[0]; + matmul->compute_.deep_ = a_shape[1]; + } + + if (matmul->b_const_ || matmul->infer_shape_) { + int *b_shape = self->in_[SECOND_INPUT]->shape_; + matmul->compute_.col_ = b_shape[0]; + matmul->compute_.deep_ = b_shape[1]; + } + + matmul->batch_ = 1; + matmul->a_batch_ = 1; + matmul->b_batch_ = 1; + + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + param->a_transpose_ = false; + param->b_transpose_ = true; + + int ret = MatmulBaseMallocBatchOffset(matmul); + if (ret != NNACL_OK) { + return ret; + } + + return MatmulBasePrepare(self); +} + +int FullConnectionResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + NNACL_CHECK_TRUE_RET(self->out_[0]->shape_size_ > 0, NNACL_ERR); + + int row = 1; + for (size_t i = 0; i < self->out_[0]->shape_size_ - 1; ++i) { + row *= (self->out_[OUTPUT_INDEX]->shape_)[i]; + } + matmul->compute_.row_ = row; + matmul->compute_.col_ = (self->out_[OUTPUT_INDEX]->shape_)[self->out_[0]->shape_size_ - 1]; + matmul->compute_.deep_ = self->in_[SECOND_INPUT]->shape_[SECOND_INPUT]; + + return MatmulBaseResize(self); +} + +KernelBase *CreateFullconnection(OpParameter *param, int data_type) { + KernelBase *kernel = NULL; + if (data_type == kNumberTypeFloat32) { + kernel = CreateMatmulKernel(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + kernel->Prepare = FullConnectionPrepare; + kernel->Resize = FullConnectionResize; + } + return kernel; +} + +REG_KERNEL_CREATOR(PrimType_FullConnection, kNumberTypeFloat32, CreateFullconnection); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.h new file mode 100644 index 00000000..a54116d9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fullconnection.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_KERNEL_FULLCONNECTION_H_ +#define NNACL_KERNEL_FULLCONNECTION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateFullconnection(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FULLCONNECTION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.c new file mode 100644 index 00000000..5e49da11 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.c @@ -0,0 +1,327 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/fused_batch_norm.h" +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/batchnorm_parameter.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" +#include "nnacl_c/fp32/scale_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/scale_fp16.h" +#include "nnacl_c/fp16/batchnorm_fp16.h" +#endif + +int FusedBatchNormInitScaleParam(FusedBatchNormStruct *fused_batch_norm) { + ScaleStruct *scale = &fused_batch_norm->scale_param_; + scale->base_.thread_nr_ = fused_batch_norm->bn_.base_.thread_nr_; + + scale->axis_ = kNHWC_C; + TensorC *in_tensor = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]; + if (in_tensor->shape_size_ != DIMENSION_4D) { + return NNACL_FUSED_BATCH_NORM_NO_CHANGE; + } + + scale->outer_size_ = 1; + for (int i = 0; i < scale->axis_; i++) { + scale->outer_size_ *= in_tensor->shape_[i]; + } + scale->axis_size_ = in_tensor->shape_[Index3]; + scale->inner_size_ = 1; + return NNACL_OK; +} + +void FusedBatchNormCalculateScaleF32(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { + float *fp32_scale_origin = (float *)scale_data; + float *fp32_var_origin = (float *)var_data; + float *fp32_bias_origin = (float *)bias_data; + float *fp32_mean_origin = (float *)mean_data; + + float *fp32_scale = (float *)fbn->scale_; + for (int i = 0; i < kernel_num; i++) { + fp32_scale[i] = fp32_scale_origin[i] / sqrtf(fp32_var_origin[i] + eps); + } + + float *fp32_offset = (float *)fbn->offset_; + for (int i = 0; i < kernel_num; i++) { + fp32_offset[i] = fp32_bias_origin[i] - fp32_mean_origin[i] * fp32_scale[i]; + } +} + +void FusedBatchNormCalculateScaleF16(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { +#ifdef ENABLE_FP16 + float16_t *fp16_scale_origin = (float16_t *)scale_data; + float16_t *fp16_var_origin = (float16_t *)var_data; + float16_t *fp16_bias_origin = (float16_t *)bias_data; + float16_t *fp16_mean_origin = (float16_t *)mean_data; + + float16_t *fp16_scale = (float16_t *)fbn->scale_; + for (int i = 0; i < kernel_num; i++) { + fp16_scale[i] = fp16_scale_origin[i] / 
sqrtf(fp16_var_origin[i] + eps); + } + + float16_t *fp16_offset = (float16_t *)fbn->offset_; + for (int i = 0; i < kernel_num; i++) { + fp16_offset[i] = fp16_bias_origin[i] - fp16_mean_origin[i] * fp16_scale[i]; + } +#endif +} + +void FusedBatchNormRunFp16(FusedBatchNormStruct *fused_batch_norm, int task_id) { +#ifdef ENABLE_FP16 + void *in_data = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_; + void *out_data = fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_; + + if (fused_batch_norm->is_scale_) { + DoScaleFp16((float16_t *)in_data, (float16_t *)out_data, (float16_t *)fused_batch_norm->scale_, + (float16_t *)fused_batch_norm->offset_, task_id, &fused_batch_norm->scale_param_); + } else { + FusedBatchNormFp16((float16_t *)in_data, (float16_t *)fused_batch_norm->scale_, + (float16_t *)fused_batch_norm->offset_, (float16_t *)fused_batch_norm->bn_.mean_, + (float16_t *)fused_batch_norm->bn_.variance_, &fused_batch_norm->bn_, task_id, + fused_batch_norm->bn_.base_.thread_nr_, (float16_t *)out_data); + } +#endif +} + +int FusedBatchNormBatchnorm2Scale(FusedBatchNormStruct *fused_batch_norm, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { + int ret = FusedBatchNormInitScaleParam(fused_batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + ExecEnv *env = fused_batch_norm->bn_.base_.env_; + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + fused_batch_norm->scale_ = env->Alloc(env->allocator_, NNACLGetSize(scale_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_); + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + fused_batch_norm->offset_ = env->Alloc(env->allocator_, NNACLGetSize(offset_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_); + + // new scale: scale / sqrt(variance + eps) + // new bias: bias - scale * mean / sqrt(variance + eps) + if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) { + FusedBatchNormCalculateScaleF16(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num); + } else { + FusedBatchNormCalculateScaleF32(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num); + } + + fused_batch_norm->is_scale_ = true; + return NNACL_OK; +} + +int FusedBatchNormInitConstTensor(FusedBatchNormStruct *fused_batch_norm) { + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT]; + TensorC *variance_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT]; + + if (!fused_batch_norm->bn_.base_.train_session_) { + int ret = FusedBatchNormBatchnorm2Scale( + fused_batch_norm, (float *)scale_tensor->data_, (float *)offset_tensor->data_, (float *)mean_tensor->data_, + (float *)variance_tensor->data_, fused_batch_norm->bn_.epsilon_, NNACLGetElementNum(scale_tensor)); + if (ret == NNACL_OK) { + return NNACL_OK; + } else { + fused_batch_norm->bn_.base_.Release(&fused_batch_norm->bn_.base_); + if (ret != NNACL_FUSED_BATCH_NORM_NO_CHANGE) { + return NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED; + } + } + } + + ExecEnv *env = fused_batch_norm->bn_.base_.env_; + fused_batch_norm->scale_ = env->Alloc(env->allocator_, NNACLGetSize(scale_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_); + (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor)); +
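+ /* Scale folding was not applied on this path, so the remaining const inputs + * are copied the same way as scale_: FusedBatchNormRunFp32/Fp16 reads these + * kernel-owned buffers rather than the original input tensors. */ +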
fused_batch_norm->offset_ = env->Alloc(env->allocator_, NNACLGetSize(offset_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_); + (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + fused_batch_norm->bn_.mean_ = env->Alloc(env->allocator_, NNACLGetSize(mean_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.mean_); + (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, NNACLGetSize(mean_tensor)); + fused_batch_norm->bn_.variance_ = env->Alloc(env->allocator_, NNACLGetSize(variance_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.variance_); + (void)memcpy(fused_batch_norm->bn_.variance_, variance_tensor->data_, NNACLGetSize(variance_tensor)); + return NNACL_OK; +} + +void FusedBatchNormRunFp32(FusedBatchNormStruct *fused_batch_norm, int task_id) { + void *in_data = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_; + void *out_data = fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_; + + if (fused_batch_norm->is_scale_) { + DoScale((float *)in_data, (float *)out_data, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_, + task_id, &fused_batch_norm->scale_param_); + } else { + FusedBatchNormFp32((float *)in_data, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_, + (float *)fused_batch_norm->bn_.mean_, (float *)fused_batch_norm->bn_.variance_, + &fused_batch_norm->bn_, task_id, fused_batch_norm->bn_.base_.thread_nr_, (float *)out_data); + } +} + +int FusedBatchNormRun(void *cdata, int task_id, float l, float r) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) { + FusedBatchNormRunFp16(fused_batch_norm, task_id); + } else if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat32) { + FusedBatchNormRunFp32(fused_batch_norm, task_id); + } + return NNACL_OK; +} + +int FusedBatchNormTrainComputeInit(FusedBatchNormStruct *fbn) { + if (fbn->bn_.base_.out_size_ < Num5) { + return NNACL_OK; + } + + TensorC *out_scale = fbn->bn_.base_.out_[SECOND_INPUT]; + TensorC *out_offset = fbn->bn_.base_.out_[THIRD_INPUT]; + TensorC *out_mean = fbn->bn_.base_.out_[FOURTH_INPUT]; + TensorC *out_var = fbn->bn_.base_.out_[FIFTH_INPUT]; + + void *current_mean = fbn->bn_.mean_; + void *current_var = fbn->bn_.variance_; + + bool schema_trained = ((BatchNormParameter *)fbn->bn_.base_.param_)->is_training_; + if (fbn->train_mode_ && schema_trained && fbn->bn_.base_.in_size_ >= Num5) { + TensorC *in_tensor = fbn->bn_.base_.in_[FIRST_INPUT]; + TensorC *scale_tensor = fbn->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fbn->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fbn->bn_.base_.in_[FOURTH_INPUT]; + TensorC *var_tensor = fbn->bn_.base_.in_[FIFTH_INPUT]; + if (in_tensor->data_ == NULL || scale_tensor->data_ == NULL || offset_tensor->data_ == NULL || + mean_tensor->data_ == NULL || var_tensor->data_ == NULL) { + return NNACL_FUSED_BATCH_TRAIN_DATA_INVALID; + } + + memset(current_mean, 0, NNACLGetSize(mean_tensor)); + memset(current_var, 0, NNACLGetSize(var_tensor)); + + bool isBatch2d = true; + if (fbn->bn_.base_.in_[FIRST_INPUT]->shape_size_ == Num2) isBatch2d = false; + + if (fbn->bn_.data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + FusedBatchNormFp16MeanVar((float16_t *)in_tensor->data_, (float16_t *)current_mean, current_var, &fbn->bn_, + (float16_t *)mean_tensor->data_, (float16_t 
*)var_tensor->data_); +#endif + } else { + FusedBatchNormFp32MeanVar((float *)in_tensor->data_, (float *)current_mean, current_var, &fbn->bn_, + (float *)mean_tensor->data_, (float *)var_tensor->data_, isBatch2d); + } + + (void)memcpy(out_scale->data_, scale_tensor->data_, NNACLGetSize(out_scale)); + (void)memcpy(out_offset->data_, offset_tensor->data_, NNACLGetSize(out_offset)); + (void)memcpy(out_mean->data_, current_mean, NNACLGetSize(out_mean)); + (void)memcpy(out_var->data_, current_var, NNACLGetSize(out_var)); + + // Copy to local variables + (void)memcpy(fbn->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor)); + (void)memcpy(fbn->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + + fbn->trained_ = true; // trained at least once + return NNACL_OK; + } + + if (fbn->bn_.base_.train_session_) { + (void)memcpy(out_scale->data_, fbn->scale_, NNACLGetSize(out_scale)); + (void)memcpy(out_offset->data_, fbn->offset_, NNACLGetSize(out_offset)); + (void)memcpy(out_mean->data_, current_mean, NNACLGetSize(out_mean)); + (void)memcpy(out_var->data_, current_var, NNACLGetSize(out_var)); + } + + return NNACL_OK; +} + +int FusedBatchNormCompute(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + int ret = FusedBatchNormTrainComputeInit(fused_batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, FusedBatchNormRun, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int FusedBatchNormReSize(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + int ret = BatchNormFillParam(&fused_batch_norm->bn_); + if (ret != NNACL_OK) { + return ret; + } + + (void)self->Release(self); + + return FusedBatchNormInitConstTensor(fused_batch_norm); +} + +int FusedBatchNormPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < FIVE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + fused_batch_norm->bn_.momentum_ = ((BatchNormParameter *)self->param_)->momentum_; + fused_batch_norm->bn_.epsilon_ = ((BatchNormParameter *)self->param_)->epsilon_; + return NNACL_OK; +} + +int FusedBatchNormRelease(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + (void)BatchNormRelease(&fused_batch_norm->bn_.base_); + + if (fused_batch_norm->scale_ != NULL) { + self->env_->Free(self->env_->allocator_, fused_batch_norm->scale_); + fused_batch_norm->scale_ = NULL; + } + if (fused_batch_norm->offset_ != NULL) { + self->env_->Free(self->env_->allocator_, fused_batch_norm->offset_); + fused_batch_norm->offset_ = NULL; + } + return NNACL_OK; +} + +KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)malloc(sizeof(FusedBatchNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fused_batch_norm); + memset(fused_batch_norm, 0, sizeof(FusedBatchNormStruct)); + fused_batch_norm->bn_.data_type_ = data_type; + fused_batch_norm->bn_.base_.Prepare = FusedBatchNormPrepare; + fused_batch_norm->bn_.base_.Resize = FusedBatchNormReSize; + fused_batch_norm->bn_.base_.Release = FusedBatchNormRelease; + 
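+ /* Compute fans FusedBatchNormRun out over the thread pool; the fp16 vs fp32 + * path is picked from bn_.data_type_ at run time, so the same creator backs + * both the kNumberTypeFloat16 and kNumberTypeFloat32 registrations below. */ +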
fused_batch_norm->bn_.base_.Compute = FusedBatchNormCompute; + return (KernelBase *)fused_batch_norm; +} + +REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat16, CreateFusedBatchNorm) +REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat32, CreateFusedBatchNorm) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.h new file mode 100644 index 00000000..15193c9c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/fused_batch_norm.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_FUSED_BATCH_NORM_H_ +#define NNACL_KERNEL_FUSED_BATCH_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/batch_norm.h" +#include "nnacl_c/kernel/scale.h" + +typedef struct FusedBatchNormStruct { + BatchNormStruct bn_; + ScaleStruct scale_param_; + void *scale_; + void *offset_; + bool is_scale_; + bool trained_; + bool train_mode_; +} FusedBatchNormStruct; + +KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FUSED_BATCH_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.c new file mode 100644 index 00000000..532640b2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.c @@ -0,0 +1,241 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/gather.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +#define kGatherMinCostPerThread 16384 + +void GatherHandleCopy(GatherStruct *gather, int8_t **int8_in, int8_t **int8_out, int begin, int end, + int byte_in_stride) { + for (; begin < end; ++begin) { + int index = gather->indices_data_[begin]; + index = (index < 0 ? 
index + gather->limit_ : index); + if (index < 0 || index >= gather->limit_) { + memset(*int8_out, 0, gather->byte_inner_size_); + } else { + memcpy(*int8_out, *int8_in + index * gather->byte_inner_size_, gather->byte_inner_size_); + } + *int8_out += gather->byte_inner_size_; + } + *int8_in += byte_in_stride; +} + +int GatherRun(void *cdata, int task_id, float l, float r) { + GatherStruct *gather = (GatherStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(gather); + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id >= gather->block_infos_size_, NNACL_ERR); + + int8_t *int8_in = (int8_t *)(gather->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(int8_in); + int8_t *int8_out = (int8_t *)(gather->base_.out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(int8_out); + int begin_batch = gather->block_infos_[task_id].begin_batch_; + int begin_index = gather->block_infos_[task_id].begin_index_; + int end_batch = gather->block_infos_[task_id].end_batch_; + int end_index = gather->block_infos_[task_id].end_index_; + int64_t byte_in_stride = gather->limit_ * gather->byte_inner_size_; + int8_in += begin_batch * byte_in_stride; + int8_out += begin_batch * gather->indices_size_ * gather->byte_inner_size_ + begin_index * gather->byte_inner_size_; + if (begin_batch == end_batch) { + GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, end_index, byte_in_stride); + return NNACL_OK; + } + GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, gather->indices_size_, byte_in_stride); + ++begin_batch; + for (; begin_batch < end_batch; ++begin_batch) { + GatherHandleCopy(gather, &int8_in, &int8_out, 0, gather->indices_size_, byte_in_stride); + } + GatherHandleCopy(gather, &int8_in, &int8_out, 0, end_index, byte_in_stride); + return NNACL_OK; +} + +int AssignGatherIndicesData(GatherStruct *gather, bool is_indices_int32) { + TensorC *indices_tensor = gather->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor->data_); + + if (is_indices_int32) { + gather->indices_data_ = (int *)(indices_tensor->data_); + return NNACL_OK; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->indices_size_, (int)(sizeof(int)), NNACL_ERR); + gather->indices_data_ = + (int *)(gather->base_.env_->Alloc(gather->base_.env_->allocator_, gather->indices_size_ * sizeof(int))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather->indices_data_); + + switch (indices_tensor->data_type_) { + case kNumberTypeInt64: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((int64_t *)indices_tensor->data_)[i]; + } + break; + case kNumberTypeFloat: + case kNumberTypeFloat32: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((float *)indices_tensor->data_)[i]; + } + break; + case kNumberTypeBool: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((bool *)indices_tensor->data_)[i]; + } + break; + default: + // the converted-indices buffer was allocated above; release it before rejecting the unsupported type + gather->base_.env_->Free(gather->base_.env_->allocator_, gather->indices_data_); + gather->indices_data_ = NULL; + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int InitGatherDynamicStatus(GatherStruct *gather) { + int *in_shape = gather->base_.in_[FIRST_INPUT]->shape_; + int in_rank = (int)gather->base_.in_[FIRST_INPUT]->shape_size_; + NNACL_CHECK_TRUE_RET(gather->axis_ >= 0 && gather->axis_ < in_rank, NNACL_GATHER_AXIS_INVALID); + gather->limit_ = in_shape[gather->axis_]; + gather->outer_size_ = 1; + for (int i = 0; i < gather->axis_; ++i) { + gather->outer_size_ *= in_shape[i]; + } + gather->byte_inner_size_ = (int)DataTypeCSize(gather->base_.out_[OUTPUT_INDEX]->data_type_); + for (int i
= gather->axis_ + 1; i < in_rank; ++i) { + gather->byte_inner_size_ *= in_shape[i]; + } + gather->indices_size_ = NNACLGetElementNum(gather->base_.in_[SECOND_INPUT]); + return NNACL_OK; +} + +void GatherUpdateThreadNumProcess(GatherStruct *gather) { + int all_bytes = NNACLGetSize(gather->base_.out_[OUTPUT_INDEX]); + if (all_bytes <= kGatherMinCostPerThread) { + gather->base_.thread_nr_ = 1; + return; + } + + gather->base_.thread_nr_ = + gather->base_.UpdateThread(TC_PTYPE(PrimType_Gather), 0, gather->byte_inner_size_, + NNACLGetSize(gather->base_.out_[OUTPUT_INDEX]), gather->base_.thread_nr_); + return; +} + +int ChooseGatherThreadCuttingStrategy(GatherStruct *gather) { + gather->block_infos_size_ = 0; + if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) { + return NNACL_OK; + } + GatherUpdateThreadNumProcess(gather); + if (gather->base_.thread_nr_ > GATHER_BLOCK_INFOS_SIZE) { + gather->base_.thread_nr_ = GATHER_BLOCK_INFOS_SIZE; + } + + if (gather->base_.thread_nr_ == 1) { + gather->block_infos_[gather->block_infos_size_].begin_batch_ = 0; + gather->block_infos_[gather->block_infos_size_].begin_index_ = 0; + gather->block_infos_[gather->block_infos_size_].end_batch_ = gather->outer_size_; + gather->block_infos_[gather->block_infos_size_].end_index_ = 0; + gather->block_infos_size_++; + } else { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->outer_size_, gather->indices_size_, NNACL_ERR); + int total_block = gather->outer_size_ * gather->indices_size_; + int block_size = total_block / gather->base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_size, gather->base_.thread_nr_, NNACL_ERR); + int remain_block = total_block - block_size * gather->base_.thread_nr_; + int start = 0; + while (start < total_block) { + GatherBlockBoundaryInfo block_boundary_info; + block_boundary_info.begin_batch_ = start / gather->indices_size_; + block_boundary_info.begin_index_ = start % gather->indices_size_; + start += block_size; + if (remain_block > 0) { + ++start; + --remain_block; + } + if (start >= total_block) { + start = total_block; + } + block_boundary_info.end_batch_ = start / gather->indices_size_; + block_boundary_info.end_index_ = start % gather->indices_size_; + gather->block_infos_[gather->block_infos_size_++] = block_boundary_info; + } + gather->base_.thread_nr_ = gather->block_infos_size_; + } + + return NNACL_OK; +} + +int GatherResize(KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + + int status = InitGatherDynamicStatus(gather); + NNACL_CHECK_FALSE(status != NNACL_OK, status); + + return ChooseGatherThreadCuttingStrategy(gather); +} + +int GatherPrepare(struct KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_GATHER_INPUT_TENSOR_INVALID); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_GATHER_OUTPUT_TENSOR_INVALID); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]->data_); + gather->axis_ = *((int *)self->in_[THIRD_INPUT]->data_); + return NNACL_OK; +} + +int GatherCompute(struct KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + + if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) { + return NNACL_OK; + } + + bool is_indices_int32 = self->in_[SECOND_INPUT]->data_type_ == kNumberTypeInt32; + int ret = 
AssignGatherIndicesData(gather, is_indices_int32); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, GatherRun, gather, gather->base_.thread_nr_); + + if (!is_indices_int32) { + self->env_->Free(self->env_->allocator_, gather->indices_data_); + gather->indices_data_ = NULL; + } + return ret; +} + +KernelBase *CreateGather(OpParameter *param, int data_type) { + GatherStruct *gather = (GatherStruct *)malloc(sizeof(GatherStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather); + gather->indices_data_ = NULL; + gather->block_infos_size_ = 0; + gather->base_.Prepare = GatherPrepare; + gather->base_.Resize = GatherResize; + gather->base_.Release = DefaultRelease; + gather->base_.Compute = GatherCompute; + return (KernelBase *)gather; +} + +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat16, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat32, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeInt32, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeBool, CreateGather) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.h new file mode 100644 index 00000000..9a95133f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather.h @@ -0,0 +1,46 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_GATHER_H_ +#define NNACL_KERNEL_GATHER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +#define GATHER_BLOCK_INFOS_SIZE 32 + +typedef struct GatherBlockBoundaryInfo { + int64_t begin_batch_; + int64_t begin_index_; + int64_t end_batch_; + int64_t end_index_; +} GatherBlockBoundaryInfo; + +typedef struct GatherStruct { + KernelBase base_; + int axis_; + int limit_; + int outer_size_; + int indices_size_; + int byte_inner_size_; + int block_infos_size_; + int *indices_data_; + GatherBlockBoundaryInfo block_infos_[GATHER_BLOCK_INFOS_SIZE]; +} GatherStruct; + +KernelBase *CreateGather(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.c new file mode 100644 index 00000000..f80909bd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.c @@ -0,0 +1,124 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/gather_d.h" +#include "nnacl_c/gather_parameter.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/base/gather_d_base.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +typedef struct GatherDStru { + KernelBase base; +} GatherDStru; + +int GatherDPrepare(struct KernelBase *self) { + GatherDStru *gather_d = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d); + GatherParameter *param = (GatherParameter *)gather_d->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < kInputSize2 || self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + NNACL_CHECK_NULL_RETURN_ERR(gather_d->base.in_[1]); + NNACL_CHECK_NULL_RETURN_ERR(gather_d->base.in_[1]->data_); + param->axis_ = ((int *)(gather_d->base.in_[1]->data_))[0]; + return NNACL_OK; +} + +int GatherDResize(struct KernelBase *self) { + GatherDStru *gather_d = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d); + GatherParameter *param = (GatherParameter *)gather_d->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + int input_rank = (int)gather_d->base.in_[0]->shape_size_; + NNACL_CHECK_FALSE(param->axis_ >= input_rank || param->axis_ < -input_rank, NNACL_GATHER_D_AXIS_INVALID); + + if (param->axis_ < 0) { + param->axis_ = param->axis_ + input_rank; + } + return NNACL_OK; +} + +int GatherDCompute(struct KernelBase *self) { + GatherDStru *gather_d_stru = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d_stru); + GatherParameter *param = (GatherParameter *)gather_d_stru->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + TensorC *input = gather_d_stru->base.in_[0]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = gather_d_stru->base.out_[0]; + NNACL_CHECK_NULL_RETURN_ERR(output); + const void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + const void *index_data = gather_d_stru->base.in_[2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(index_data); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + size_t input_shape[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < input->shape_size_; i++) { + input_shape[i] = input->shape_[i]; + } + size_t output_shape[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < output->shape_size_; i++) { + output_shape[i] = output->shape_[i]; + } + + int input_dtype = input->data_type_; + int index_dtype = gather_d_stru->base.in_[THIRD_INPUT]->data_type_; + int status = NNACL_ERR; + if (index_dtype == kNumberTypeInt32) { + if (input_dtype == kNumberTypeFloat32) { + status = GATHER_D(float, int32_t, (float *)output_data, (float *)input_data, (int32_t *)index_data, input_shape, + input->shape_size_, output_shape, output->shape_size_, param->axis_); + } else if (input_dtype == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + status = GATHER_D(float16_t, int32_t, (float16_t *)output_data, (float16_t *)input_data, (int32_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); +#endif + } else if (input_dtype == kNumberTypeInt32) { + status = GATHER_D(int32_t, int32_t, (int32_t *)output_data, (int32_t *)input_data, (int32_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); + } + } else if (index_dtype == kNumberTypeInt64) { + if (input_dtype == kNumberTypeFloat32) { + status = GATHER_D(float, int64_t, (float *)output_data, (float *)input_data, (int64_t *)index_data, input_shape, + input->shape_size_, output_shape, output->shape_size_, param->axis_); + } else if
(input_dtype == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + status = GATHER_D(float16_t, int64_t, (float16_t *)output_data, (float16_t *)input_data, (int64_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); +#endif + } else if (input_dtype == kNumberTypeInt32) { + status = GATHER_D(int32_t, int64_t, (int32_t *)output_data, (int32_t *)input_data, (int64_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); + } + } + return status; +} + +KernelBase *CreateGatherD(OpParameter *param, int data_type) { + GatherDStru *gather_d = (GatherDStru *)malloc(sizeof(GatherDStru)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_d); + gather_d->base.Prepare = GatherDPrepare; + gather_d->base.Resize = GatherDResize; + gather_d->base.Release = DefaultRelease; + gather_d->base.Compute = GatherDCompute; + return (KernelBase *)gather_d; +} + +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeFloat32, CreateGatherD); +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeInt32, CreateGatherD); +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeFloat16, CreateGatherD); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.h new file mode 100644 index 00000000..eb6f69e5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_d.h @@ -0,0 +1,25 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_GATHER_D_H_ +#define NNACL_KERNEL_GATHER_D_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateGatherD(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_D_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.c new file mode 100644 index 00000000..c1b29047 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.c @@ -0,0 +1,168 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/gather_nd.h" +#include "nnacl_c/fp32/gatherNd_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/nnacl_common.h" + +int GatherNdInitOffset(GatherNdStruct *gather_nd) { + TensorC *input_tensor = gather_nd->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *indices_tensor = gather_nd->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor); + + if (indices_tensor->shape_size_ < 1) { + return NNACL_GATHER_ND_INDICES_RANK_INVALID; + } + + int in_rank = input_tensor->shape_size_; + int idx_lastshape = indices_tensor->shape_[indices_tensor->shape_size_ - 1]; + if (idx_lastshape > in_rank) { + return NNACL_GATHER_ND_INDICES_SHAPE_INVALID; + } + + gather_nd->area_ = 1; + for (int i = idx_lastshape; i < input_tensor->shape_size_; ++i) { + gather_nd->area_ *= input_tensor->shape_[i]; + } + + int in_stride[MAX_SHAPE_SIZE] = {0}; + in_stride[in_rank - 1] = 1; + for (int i = in_rank - 2; i >= 0; --i) { + in_stride[i] = input_tensor->shape_[i + 1] * in_stride[i + 1]; + } + + int idx_stride = idx_lastshape; + (void)memset(gather_nd->in_offset_, 0, gather_nd->count_ * sizeof(int)); + + if (indices_tensor->data_type_ == kNumberTypeInt || indices_tensor->data_type_ == kNumberTypeInt32) { + int32_t *indices_ptr = (int32_t *)indices_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(indices_ptr); + for (int j = 0; j < gather_nd->count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + } else if (indices_tensor->data_type_ == kNumberTypeInt64) { + int64_t *indices_ptr = (int64_t *)indices_tensor->data_; + for (int j = 0; j < gather_nd->count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + } else { + return NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID; + } + + return NNACL_OK; +} + +int GatherNdRun(void *cdata, int task_id, float l, float r) { + GatherNdStruct *gather_nd = (GatherNdStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + TensorC *input = gather_nd->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, gather_nd->thread_stride_, NNACL_ERR); + int count = NNACL_MIN(gather_nd->thread_stride_, gather_nd->count_ - task_id * gather_nd->thread_stride_); + if (count <= 0) { + return NNACL_OK; + } + + int offset = task_id * gather_nd->thread_stride_; + int dtype_len = DataTypeCSize(input->data_type_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(offset, gather_nd->area_, NNACL_ERR); + int8_t *out_ptr = (int8_t *)gather_nd->out_ptr_ + offset * gather_nd->area_ * dtype_len; + return GatherNd(gather_nd->in_ptr_, out_ptr, gather_nd->in_offset_ + offset, gather_nd->area_, count, dtype_len); +} + +int GatherNdCompute(KernelBase *self) { + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + gather_nd->in_ptr_ = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd->in_ptr_); + + TensorC *output = self->out_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(output); + gather_nd->out_ptr_ = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd->out_ptr_); + + int ret = GatherNdInitOffset(gather_nd); + if (ret != NNACL_OK) { + return ret; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, GatherNdRun, self, self->thread_nr_); +} + +int 
GatherNdRelease(KernelBase *self) { + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + if (gather_nd->in_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, gather_nd->in_offset_); + gather_nd->in_offset_ = NULL; + } + return NNACL_OK; +} + +int GatherNdResize(KernelBase *self) { + // actually invoke Release to drop any in_offset_ buffer left over from a previous Resize + (void)self->Release(self); + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + TensorC *indices_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor); + + gather_nd->count_ = 1; + for (int i = 0; i < (int)indices_tensor->shape_size_ - 1; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather_nd->count_, indices_tensor->shape_[i], NNACL_ERR); + gather_nd->count_ *= indices_tensor->shape_[i]; + } + + int max_count = INT32_MAX / sizeof(int); + if (gather_nd->count_ >= max_count) { + return NNACL_GATHER_ND_COUNT_INVALID; + } + + gather_nd->in_offset_ = self->env_->Alloc(self->env_->allocator_, gather_nd->count_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather_nd->in_offset_); + + gather_nd->base_.thread_nr_ = NNACL_MIN(gather_nd->base_.thread_nr_, gather_nd->count_); + if (gather_nd->base_.thread_nr_ != 0) { + gather_nd->thread_stride_ = UP_DIV(gather_nd->count_, gather_nd->base_.thread_nr_); + } + return NNACL_OK; +} + +KernelBase *CreateGatherNd(OpParameter *param, int data_type) { + GatherNdStruct *gather_nd = (GatherNdStruct *)malloc(sizeof(GatherNdStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_nd); + memset(gather_nd, 0, sizeof(GatherNdStruct)); + + gather_nd->base_.Prepare = DefaultPrepare2In1Out; + gather_nd->base_.Resize = GatherNdResize; + gather_nd->base_.Compute = GatherNdCompute; + gather_nd->base_.Release = GatherNdRelease; + return (KernelBase *)gather_nd; +} + +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeBool, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeInt32, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat32, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat16, CreateGatherNd); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.h new file mode 100644 index 00000000..bb1f87f8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/gather_nd.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef NNACL_KERNEL_GATHER_ND_H_ +#define NNACL_KERNEL_GATHER_ND_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int *in_offset_; + int count_; + int area_; + int thread_stride_; + void *in_ptr_; + void *out_ptr_; +} GatherNdStruct; + +KernelBase *CreateGatherNd(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_ND_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.c new file mode 100644 index 00000000..9a68f98c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.c @@ -0,0 +1,419 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/group_convolution.h" +#include "nnacl_c/kernel/convolution_delegate.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int GroupConvBasePrepare(GroupConvolutionStruct *group_conv) { + for (int i = 0; i < group_conv->group_; ++i) { + KernelBase *sub_conv = group_conv->group_convs_[i]; + NNACL_CHECK_NULL_RETURN_ERR(sub_conv); + int ret = sub_conv->Prepare(sub_conv); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int GroupConvCreatorNewInputTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv) { + TensorC *in_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(in_tensor); + in_tensor->format_ = Format_NHWC; + in_tensor->category_ = VarTensor; + in_tensor->data_type_ = group_conv->data_type_; + in_tensor->shape_size_ = DIMENSION_4D; + in_tensor->shape_[Index0] = INVALID_SHAPE; + new_conv->in_[FIRST_INPUT] = in_tensor; + return NNACL_OK; +} + +int GroupConvCreatorNewOutputTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv) { + TensorC *out_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_tensor); + out_tensor->format_ = Format_NHWC; + out_tensor->category_ = VarTensor; + out_tensor->data_type_ = group_conv->data_type_; + out_tensor->shape_size_ = DIMENSION_4D; + out_tensor->shape_[Index0] = INVALID_SHAPE; + new_conv->out_[OUTPUT_INDEX] = out_tensor; + return NNACL_OK; +} + +TensorC *CreateConstTensor(const TensorC *tensor, const int *shape, const int shape_size, const int index) { + NNACL_CHECK_NULL_RETURN_NULL(tensor->data_); + + TensorC *new_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(new_tensor); + new_tensor->data_type_ = tensor->data_type_; + new_tensor->format_ = Format_NHWC; + new_tensor->category_ = ConstTensor; + new_tensor->shape_size_ = shape_size; + memcpy(new_tensor->shape_, shape, shape_size * sizeof(int)); + + int size = NNACLGetSize(new_tensor); + if (size <= 0) { + free(new_tensor); + return NULL; + } + + void *data = malloc(size); + if (data == NULL) { + free(new_tensor); + return NULL; + } + new_tensor->data_ = 
data; + + uint8_t *new_tensor_data = (uint8_t *)tensor->data_ + index * size; + memcpy(new_tensor->data_, new_tensor_data, size); + return new_tensor; +} + +int GroupConvCreatorNewConstTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv, int group_id) { + TensorC *origin_weight = group_conv->conv_base_.base_.in_[SECOND_INPUT]; + int shape[] = {group_conv->sub_out_c_, NNACLGetHeight(origin_weight), NNACLGetWidth(origin_weight), + group_conv->sub_in_c_}; + TensorC *weight_tensor = CreateConstTensor(origin_weight, shape, DIMENSION_4D, group_id); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(weight_tensor); + new_conv->in_[SECOND_INPUT] = weight_tensor; + + if (group_conv->conv_base_.base_.in_size_ == THREE_TENSOR) { + TensorC *bias_weight = group_conv->conv_base_.base_.in_[THIRD_INPUT]; + TensorC *bias_tensor = CreateConstTensor(bias_weight, &group_conv->sub_out_c_, DIMENSION_1D, group_id); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(bias_tensor); + new_conv->in_[THIRD_INPUT] = bias_tensor; + } + return NNACL_OK; +} + +int GroupConvCreatorSetShapeOfTensors(GroupConvolutionStruct *group_conv) { + ConvParameter *origin_conv_param = (ConvParameter *)group_conv->conv_base_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(origin_conv_param); + ConvParameter *new_conv_param = &group_conv->new_conv_param_; + NNACL_CHECK_NULL_RETURN_ERR(new_conv_param); + memcpy(new_conv_param, origin_conv_param, sizeof(ConvParameter)); + + TensorC *weight_tensor = group_conv->conv_base_.base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + NNACL_CHECK_FALSE(origin_conv_param->group_ == 0, NNACL_GROUP_CONVOLUTION_GROUP_INVALID); + NNACL_CHECK_FALSE(weight_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + NNACL_CHECK_FALSE(origin_conv_param->kernel_h_ != NNACLGetHeight(weight_tensor), + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + NNACL_CHECK_FALSE(origin_conv_param->kernel_w_ != NNACLGetWidth(weight_tensor), + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + ConvComputeParam *compute = &group_conv->conv_base_.compute_; + group_conv->ori_in_c_ = compute->in_c_; + group_conv->ori_out_c_ = compute->out_c_; + group_conv->sub_in_c_ = compute->in_c_ / group_conv->group_; + group_conv->sub_out_c_ = compute->out_c_ / group_conv->group_; + + new_conv_param->input_channel_ = group_conv->sub_in_c_; + new_conv_param->output_channel_ = group_conv->sub_out_c_; + new_conv_param->group_ = origin_conv_param->group_; + + return NNACL_OK; +} + +int GroupConvSetSubConvInfo(GroupConvolutionStruct *group_conv, KernelBase *new_conv, int group_id) { + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + NNACL_CHECK_NULL_RETURN_ERR(new_conv); + + ConvolutionBaseStruct *sub_conv = (ConvolutionBaseStruct *)new_conv; + (void)ConvBaseUpdateParamInfo(&sub_conv->compute_, &group_conv->new_conv_param_); + + sub_conv->infershape_done_ = group_conv->conv_base_.infershape_done_; + sub_conv->shaing_manager_ = group_conv->conv_base_.shaing_manager_; + sub_conv->get_sharing_weight_ = group_conv->conv_base_.get_sharing_weight_; + sub_conv->free_sharing_weight_ = group_conv->conv_base_.free_sharing_weight_; + sub_conv->is_sharing_pack_ = group_conv->conv_base_.is_sharing_pack_; + + new_conv->env_ = group_conv->conv_base_.base_.env_; + new_conv->param_ = &group_conv->new_conv_param_.op_parameter_; + new_conv->thread_nr_ = group_conv->conv_base_.base_.thread_nr_; + new_conv->train_session_ = group_conv->conv_base_.base_.train_session_; + new_conv->UpdateThread = group_conv->conv_base_.base_.UpdateThread; + new_conv->in_size_ = 
group_conv->conv_base_.base_.in_size_; + new_conv->out_size_ = group_conv->conv_base_.base_.out_size_; + + new_conv->in_ = (TensorC **)malloc(new_conv->in_size_ * sizeof(TensorC *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv->in_); + memset(new_conv->in_, 0, new_conv->in_size_ * sizeof(TensorC *)); + new_conv->out_ = (TensorC **)malloc(new_conv->out_size_ * sizeof(TensorC *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv->out_); + memset(new_conv->out_, 0, new_conv->out_size_ * sizeof(TensorC *)); + + // create new input for each group + int ret = GroupConvCreatorNewInputTensor(group_conv, new_conv); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + + // const tensor + ret = GroupConvCreatorNewConstTensor(group_conv, new_conv, group_id); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + + // create new output tensor + ret = GroupConvCreatorNewOutputTensor(group_conv, new_conv); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + return NNACL_OK; +} + +int GroupConvConcatOutputRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)cdata; + + int plane_step = UP_DIV(group_conv->conv_base_.compute_.out_hw_, group_conv->conv_base_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_step, task_id, NNACL_ERR); + int begin_plane = plane_step * task_id; + int end_plane = NNACL_MIN(group_conv->conv_base_.compute_.out_hw_, plane_step * (task_id + 1)); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->sub_out_c_, NNACL_ERR); + float *src_ptr = group_conv->sub_out_src_ + begin_plane * group_conv->sub_out_c_; + float *dst_ptr = group_conv->sub_out_dst_ + begin_plane * group_conv->ori_out_c_; + for (int i = begin_plane; i < end_plane; ++i) { + (void)memcpy(dst_ptr, src_ptr, group_conv->sub_out_c_ * sizeof(float)); + src_ptr += group_conv->sub_out_c_; + dst_ptr += group_conv->ori_out_c_; + } + return NNACL_OK; +} + +int GroupConvPostConcat(GroupConvolutionStruct *group_conv, int group_id) { + group_conv->sub_out_src_ = (float *)group_conv->group_convs_[group_id]->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_out_src_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(group_id, group_conv->sub_out_c_, NNACL_ERR); + group_conv->sub_out_dst_ = (float *)(group_conv->origin_output_data_) + group_id * group_conv->sub_out_c_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_out_dst_); + + return group_conv->conv_base_.base_.env_->ParallelLaunch(group_conv->conv_base_.base_.env_->thread_pool_, + GroupConvConcatOutputRun, group_conv, + group_conv->conv_base_.base_.thread_nr_); +} + +int GroupConvSeparateInputRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)cdata; + + int plane_step = UP_DIV(group_conv->conv_base_.compute_.in_hw_, group_conv->conv_base_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_step, task_id, NNACL_ERR); + int begin_plane = plane_step * task_id; + int end_plane = NNACL_MIN(group_conv->conv_base_.compute_.in_hw_, plane_step * (task_id + 1)); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->ori_in_c_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->sub_in_c_, NNACL_ERR); + float *src_ptr = group_conv->sub_in_src_ + begin_plane * 
group_conv->ori_in_c_; + float *dst_ptr = group_conv->sub_in_dst_ + begin_plane * group_conv->sub_in_c_; + for (int i = begin_plane; i < end_plane; ++i) { + (void)memcpy(dst_ptr, src_ptr, group_conv->sub_in_c_ * sizeof(float)); + src_ptr += group_conv->ori_in_c_; + dst_ptr += group_conv->sub_in_c_; + } + + return NNACL_OK; +} + +int GroupConvSeparateInput(GroupConvolutionStruct *group_conv, int group_id) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(group_id, group_conv->sub_in_c_, NNACL_ERR); + + group_conv->sub_in_src_ = (float *)(group_conv->origin_input_data_) + group_id * group_conv->sub_in_c_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_in_src_); + group_conv->sub_in_dst_ = (float *)(group_conv->group_convs_[group_id]->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_in_dst_); + + return group_conv->conv_base_.base_.env_->ParallelLaunch(group_conv->conv_base_.base_.env_->thread_pool_, + GroupConvSeparateInputRun, group_conv, + group_conv->conv_base_.base_.thread_nr_); +} + +void GroupConvUpdateShape(GroupConvolutionStruct *group_conv) { + for (int i = 0; i < group_conv->group_; i++) { + TensorC *in_tensor = group_conv->conv_base_.base_.in_[FIRST_INPUT]; + int in_shape[] = {NNACLGetBatch(in_tensor), NNACLGetHeight(in_tensor), NNACLGetWidth(in_tensor), + group_conv->sub_in_c_}; + memcpy(group_conv->group_convs_[i]->in_[FIRST_INPUT]->shape_, in_shape, DIMENSION_4D * sizeof(int)); + + TensorC *out_tensor = group_conv->conv_base_.base_.out_[OUTPUT_INDEX]; + int out_shape[] = {NNACLGetBatch(out_tensor), NNACLGetHeight(out_tensor), NNACLGetWidth(out_tensor), + group_conv->sub_out_c_}; + memcpy(group_conv->group_convs_[i]->out_[OUTPUT_INDEX]->shape_, out_shape, DIMENSION_4D * sizeof(int)); + } + return; +} + +int GroupConvolutionResize(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + + (void)ConvBaseUpdateComputeInfo(&group_conv->conv_base_); + self->thread_nr_ = NNACL_MIN(NNACL_MAX(1, self->thread_nr_), group_conv->conv_base_.compute_.in_hw_); + + GroupConvUpdateShape(group_conv); + + for (int i = 0; i < group_conv->group_; ++i) { + group_conv->group_convs_[i]->thread_nr_ = self->thread_nr_; + int ret = group_conv->group_convs_[i]->Resize(group_conv->group_convs_[i]); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int GroupConvolutionCompute(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + + group_conv->origin_input_data_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->origin_input_data_); + group_conv->origin_output_data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->origin_output_data_); + + for (int i = 0; i < group_conv->group_; ++i) { + // first, malloc data for sub_kernel's tensors.
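+ // note: these buffers are transient; they are freed again at the end of each loop iteration, so peak memory stays bounded by a single sub-convolution's working set.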
+ TensorC *sub_kernel_in_tensor = group_conv->group_convs_[i]->in_[FIRST_INPUT]; + sub_kernel_in_tensor->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(sub_kernel_in_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sub_kernel_in_tensor->data_); + + TensorC *sub_kernel_out_tensor = group_conv->group_convs_[i]->out_[OUTPUT_INDEX]; + sub_kernel_out_tensor->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(sub_kernel_out_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sub_kernel_out_tensor->data_); + + // second, separate group conv input into several parts. This step must happen at runtime. + int ret = GroupConvSeparateInput(group_conv, i); + if (ret != NNACL_OK) { + return ret; + } + + // third, run each sub-kernel on its input slice + ret = group_conv->group_convs_[i]->Compute(group_conv->group_convs_[i]); + if (ret != NNACL_OK) { + return ret; + } + + // fourth, post-process: concat all outputs of the sub-kernels into one output + ret = GroupConvPostConcat(group_conv, i); + if (ret != NNACL_OK) { + return ret; + } + + // finally, free the transient per-group buffers + self->env_->Free(self->env_->allocator_, sub_kernel_in_tensor->data_); + sub_kernel_in_tensor->data_ = NULL; + self->env_->Free(self->env_->allocator_, sub_kernel_out_tensor->data_); + sub_kernel_out_tensor->data_ = NULL; + } + return NNACL_OK; +} + +int GroupConvolutionPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + NNACL_CHECK_FALSE(group_conv->group_ == 0, NNACL_GROUP_CONVOLUTION_GROUP_INVALID); + + GroupConvCreatorSetShapeOfTensors(group_conv); + + group_conv->group_convs_ = (KernelBase **)malloc(group_conv->group_ * sizeof(KernelBase *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(group_conv->group_convs_); + memset(group_conv->group_convs_, 0, group_conv->group_ * sizeof(KernelBase *)); + + for (int i = 0; i < group_conv->group_; ++i) { + KernelBase *new_conv = CreateConvlutionDelegate(&group_conv->new_conv_param_); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv); + group_conv->group_convs_[i] = new_conv; + + int ret = GroupConvSetSubConvInfo(group_conv, new_conv, i); + if (ret != NNACL_OK) { + return ret; + } + } + return GroupConvBasePrepare(group_conv); +} + +void GroupConvReleaseSubConv(KernelBase *current_conv) { + (void)current_conv->Release(current_conv); + + if (current_conv->in_ != NULL) { + for (int j = 0; j < current_conv->in_size_; j++) { + if (current_conv->in_[j] == NULL) { + continue; + } + if (NNACLIsConst(current_conv->in_[j])) { + free(current_conv->in_[j]->data_); + current_conv->in_[j]->data_ = NULL; + } + free(current_conv->in_[j]); + current_conv->in_[j] = NULL; + } + free(current_conv->in_); + current_conv->in_ = NULL; + } + + if (current_conv->out_ != NULL) { + for (int j = 0; j < current_conv->out_size_; j++) { + if (current_conv->out_[j] != NULL) { + free(current_conv->out_[j]); + current_conv->out_[j] = NULL; + } + } + free(current_conv->out_); + current_conv->out_ = NULL; + } +} + +int GroupConvolutionRelease(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (group_conv->group_convs_ != NULL) { + for (int i = 0; i < conv_param->group_; i++) { + if (group_conv->group_convs_[i] != NULL) {
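+ // tear down the cloned per-group tensors (including the deep-copied const weight/bias data) before freeing the sub-kernel itself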
GroupConvReleaseSubConv(group_conv->group_convs_[i]); + free(group_conv->group_convs_[i]); + group_conv->group_convs_[i] = NULL; + } + } + free(group_conv->group_convs_); + group_conv->group_convs_ = NULL; + } + return NNACL_OK; +} + +KernelBase *CreateGroupConvolution(ConvParameter *conv_param, TypeIdC data_type) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)malloc(sizeof(GroupConvolutionStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(group_conv); + memset(group_conv, 0, sizeof(GroupConvolutionStruct)); + + group_conv->data_type_ = data_type; + group_conv->group_ = conv_param->group_; + group_conv->conv_base_.base_.Compute = GroupConvolutionCompute; + group_conv->conv_base_.base_.Resize = GroupConvolutionResize; + group_conv->conv_base_.base_.Prepare = GroupConvolutionPrepare; + group_conv->conv_base_.base_.Release = GroupConvolutionRelease; + return (KernelBase *)group_conv; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.h new file mode 100644 index 00000000..4d061b6e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_convolution.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_GROUP_CONVOLUTION_H_ +#define NNACL_KERNEL_GROUP_CONVOLUTION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/kernel/convolution_base.h" + +typedef struct GroupConvolutionStruct { + ConvolutionBaseStruct conv_base_; + KernelBase **group_convs_; + ConvParameter new_conv_param_; + TypeIdC data_type_; + int group_; + + void *origin_input_data_; + void *origin_output_data_; + + float *sub_in_src_; + float *sub_in_dst_; + float *sub_out_src_; + float *sub_out_dst_; + + int sub_in_c_; + int ori_in_c_; + int sub_out_c_; + int ori_out_c_; +} GroupConvolutionStruct; + +KernelBase *CreateGroupConvolution(ConvParameter *conv_param, TypeIdC data_type); + +#endif // NNACL_KERNEL_GROUP_CONVOLUTION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.c new file mode 100644 index 00000000..aabbfb94 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.c @@ -0,0 +1,122 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/group_norm.h" +#include "nnacl_c/fp32/group_norm_fp32.h" +#include "nnacl_c/group_norm_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/errorcode.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/tensor_c_utils.h" + +int GroupNormResize(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < kInputSize2, NNACL_TENSOR_SIZE_INVALID); + NNACL_CHECK_FALSE(self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + self->Release(self); + + TensorC *in0 = self->in_[0]; + NNACL_CHECK_FALSE(in0->shape_size_ < C1NUM, NNACL_GROUP_NORM_SHAPE_SIZE_INVALID); + NNACL_CHECK_FALSE(in0->format_ != Format_NCHW, NNACL_GROUP_NORM_FORMAT_INVALID); + + param->unit_ = NNACLGetHeight(in0) * NNACLGetWidth(in0); + param->batch_ = NNACLGetBatch(in0); + param->channel_ = NNACLGetChannel(in0); + return self->Prepare(self); +} + +int GroupNormPrepare(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + // reject non-positive group counts before the modulo below to avoid division by zero + NNACL_CHECK_FALSE(param->num_groups_ < 0, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + NNACL_CHECK_FALSE(param->num_groups_ == 0, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + NNACL_CHECK_FALSE(param->channel_ % param->num_groups_, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + + size_t mean_var_elem_num = param->num_groups_; + param->mean_ = malloc(mean_var_elem_num * sizeof(float)); + param->variance_ = malloc(mean_var_elem_num * sizeof(float)); + if (param->mean_ == NULL || param->variance_ == NULL) { + self->Release(self); + return NNACL_MALLOC_BUFFER_FAILED; + } + return NNACL_OK; +} + +int GroupNormRelease(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + if (param->mean_ != NULL) { + free(param->mean_); + param->mean_ = NULL; + } + if (param->variance_ != NULL) { + free(param->variance_); + param->variance_ = NULL; + } + + return NNACL_OK; +} + +int GroupNormImpl(void *param, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(param); + GroupNormStru *groupnorm_stru = (GroupNormStru *)param; + GroupNormParameter *groupnorm_param = (GroupNormParameter *)groupnorm_stru->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param); + + const void *input_data = groupnorm_stru->base.in_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + const void *scale_data = groupnorm_stru->base.in_[C1NUM]->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale_data); + const void *offset_data = groupnorm_stru->base.in_[C2NUM]->data_; + NNACL_CHECK_NULL_RETURN_ERR(offset_data); + void *output_data = groupnorm_stru->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param->mean_); + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param->variance_); + + int ret = GroupNormFp32(input_data, scale_data, offset_data, groupnorm_param->mean_, groupnorm_param->variance_, + groupnorm_param, task_id, output_data); + + return ret; +} + +int GroupNormCompute(struct KernelBase *self) {
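+ // fan out GroupNormImpl once per task id; each invocation computes its slice of the normalization via GroupNormFp32, using the mean_/variance_ scratch buffers allocated in Prepare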
return self->env_->ParallelLaunch(self->env_->thread_pool_, GroupNormImpl, self, self->param_->thread_num_); +} + +KernelBase *CreateGroupNorm(OpParameter *param, int data_type) { + GroupNormStru *groupnorm = (GroupNormStru *)malloc(sizeof(GroupNormStru)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(groupnorm); + + groupnorm->base.Prepare = GroupNormPrepare; + groupnorm->base.Resize = GroupNormResize; + groupnorm->base.Release = GroupNormRelease; + groupnorm->base.Compute = GroupNormCompute; + + return (KernelBase *)groupnorm; +} + +REG_KERNEL_CREATOR(PrimType_GroupNormFusion, kNumberTypeFloat32, CreateGroupNorm); diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.h new file mode 100644 index 00000000..79ff15ab --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/group_norm.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_GROUP_NORM_H_ +#define NNACL_KERNEL_GROUP_NORM_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/group_norm_parameter.h" +#include "nnacl_c/kernel.h" + +typedef struct GroupNormStru { + KernelBase base; +} GroupNormStru; + +KernelBase *CreateGroupNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GROUP_NORM_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.c new file mode 100644 index 00000000..0913d714 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.c @@ -0,0 +1,51 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/init_exec_env.h" + +#define NNACLMaxAllocSize (2000 * 1024 * 1024) +ExecEnv nnacl_default_env; + +void *NNACLDefaultAlloc(void *allocator, size_t sz) { + if (sz == 0 || sz > NNACLMaxAllocSize) { + return NULL; + } + return malloc(sz); +} + +void NNACLDefaultFree(void *allocator, void *ptr) { free(ptr); } + +int NNACLDefaultParallelLaunch(void *threadPool, void *task, void *param, int taskNr) { + int (*function)(void *cdata, int task_id, float l, float r) = task; + int ret = 0; + for (int i = 0; i < taskNr; i++) { + ret += function(param, i, 0, 1); + } + return ret == NNACL_OK ?
NNACL_OK : NNACL_ERR; +} + +void InitDefaultExecEnv(void) { + nnacl_default_env.Free = NNACLDefaultFree; + nnacl_default_env.Alloc = NNACLDefaultAlloc; + nnacl_default_env.ParallelLaunch = NNACLDefaultParallelLaunch; +} + +void CheckExecEnv(KernelBase *base) { + if (base->env_ == NULL) { + base->env_ = &nnacl_default_env; + } + return; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.h new file mode 100644 index 00000000..ee417051 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_exec_env.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_INIT_EXEC_ENV_H_ +#define NNACL_KERNEL_INIT_EXEC_ENV_H_ + +#include "nnacl_c/kernel.h" + +#ifndef _MSC_VER +__attribute__((constructor(103))) void InitDefaultExecEnv(void); +#endif + +void CheckExecEnv(KernelBase *base); + +#endif // NNACL_KERNEL_INIT_EXEC_ENV_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.c new file mode 100644 index 00000000..43a29485 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.c @@ -0,0 +1,357 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "nnacl_c/kernel/init_vs_kernels.h" +#include "nnacl_c/kernel/activation.h" +#include "nnacl_c/kernel/arithmetic.h" +#include "nnacl_c/kernel/arithmetic_compare.h" +#include "nnacl_c/kernel/arithmetic_self.h" +#include "nnacl_c/kernel/arg_min_max.h" +#include "nnacl_c/kernel/addn.h" +#include "nnacl_c/kernel/biasadd.h" +#include "nnacl_c/kernel/batch_norm.h" +#include "nnacl_c/kernel/clip.h" +#include "nnacl_c/kernel/concat.h" +#include "nnacl_c/kernel/crop.h" +#include "nnacl_c/kernel/crop_and_resize.h" +#include "nnacl_c/kernel/exp.h" +#include "nnacl_c/kernel/depth_to_space.h" +#include "nnacl_c/kernel/fill.h" +#include "nnacl_c/kernel/fused_batch_norm.h" +#include "nnacl_c/kernel/fullconnection.h" +#include "nnacl_c/kernel/gather.h" +#include "nnacl_c/kernel/gather_d.h" +#include "nnacl_c/kernel/gather_nd.h" +#include "nnacl_c/kernel/group_norm.h" +#include "nnacl_c/kernel/log_softmax.h" +#include "nnacl_c/kernel/local_response_norm.h" +#include "nnacl_c/kernel/layer_norm.h" +#include "nnacl_c/kernel/matmul.h" +#include "nnacl_c/kernel/non_max_suppression.h" +#include "nnacl_c/kernel/non_zero.h" +#include "nnacl_c/kernel/nllloss.h" +#include "nnacl_c/kernel/prior_box.h" +#include "nnacl_c/kernel/prelu.h" +#include "nnacl_c/kernel/pad.h" +#include "nnacl_c/kernel/pow.h" +#include "nnacl_c/kernel/reshape.h" +#include "nnacl_c/kernel/reverse.h" +#include "nnacl_c/kernel/range.h" +#include "nnacl_c/kernel/rank.h" +#include "nnacl_c/kernel/scale.h" +#include "nnacl_c/kernel/shape.h" +#include "nnacl_c/kernel/reduce.h" +#include "nnacl_c/kernel/ragged_range.h" +#include "nnacl_c/kernel/stack.h" +#include "nnacl_c/kernel/strided_slice.h" +#include "nnacl_c/kernel/softmax.h" +#include "nnacl_c/kernel/size.h" +#include "nnacl_c/kernel/splice.h" +#include "nnacl_c/kernel/tile.h" +#include "nnacl_c/kernel/tril.h" +#include "nnacl_c/kernel/triu.h" +#include "nnacl_c/kernel/transpose.h" +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/kernel/unique.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/kernel/f16/arithmetic_f16.h" +#include "nnacl_c/kernel/f16/arithmetic_compare_f16.h" +#include "nnacl_c/kernel/f16/concat_f16.h" +#include "nnacl_c/kernel/f16/reduce_f16.h" +#include "nnacl_c/kernel/f16/stack_f16.h" +#endif + +void InitVSKernelF16(KernelCreator **creators) { +#ifdef ENABLE_FP16 + creators[PrimType_Abs][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Activation][REGIST_DT(kNumberTypeFloat16)] = CreateActivation; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_AddN][REGIST_DT(kNumberTypeFloat16)] = CreateAddN; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArgMinMax; + creators[PrimType_BatchNorm][REGIST_DT(kNumberTypeFloat16)] = CreateBatchNorm; + creators[PrimType_Ceil][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Concat][REGIST_DT(kNumberTypeFloat16)] = CreateConcatF16; + creators[PrimType_Cos][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Crop][REGIST_DT(kNumberTypeFloat16)] = CreateCrop; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_DepthToSpace][REGIST_DT(kNumberTypeFloat16)] = CreateDepthToSpace; + creators[PrimType_Eltwise][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Erf][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + 
creators[PrimType_Equal][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Fill][REGIST_DT(kNumberTypeFloat16)] = CreateFill; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_FlattenGrad][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Floor][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_FusedBatchNorm][REGIST_DT(kNumberTypeFloat16)] = CreateFusedBatchNorm; + creators[PrimType_Gather][REGIST_DT(kNumberTypeFloat16)] = CreateGather; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeFloat16)] = CreateGatherD; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeFloat16)] = CreateGatherNd; + creators[PrimType_Greater][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_Less][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_LayerNormFusion][REGIST_DT(kNumberTypeFloat16)] = CreateLayerNorm; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Log][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_LogSoftmax][REGIST_DT(kNumberTypeFloat16)] = CreateLogSoftmax; + creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Maximum][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Minimum][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_MulFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Neg][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_PadFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePad; + creators[PrimType_PReLUFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePRelu; + creators[PrimType_PowFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePow; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_RealDiv][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeFloat16)] = CreateReduceF16; + creators[PrimType_Rsqrt][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Round][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Reciprocal][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_ScaleFusion][REGIST_DT(kNumberTypeFloat16)] = CreateScale; + creators[PrimType_Shape][REGIST_DT(kNumberTypeFloat16)] = CreateShape; + creators[PrimType_Softmax][REGIST_DT(kNumberTypeFloat16)] = CreateSoftmax; + creators[PrimType_Stack][REGIST_DT(kNumberTypeFloat16)] = CreateStackF16; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeFloat16)] = CreateStridedSlice; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + 
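+  /* Illustrative sketch (the real factory lives outside this file): entries
+   * registered here are consumed as a two-level lookup, roughly
+   *
+   *   KernelCreator creator = creators[param->type_][REGIST_DT(data_type)];
+   *   KernelBase *kernel = (creator != NULL) ? creator(param, data_type) : NULL;
+   *
+   * so an (op, dtype) pair with no entry simply yields NULL and the caller
+   * can fall back to another kernel path. */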
creators[PrimType_SquaredDifference][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Splice][REGIST_DT(kNumberTypeFloat16)] = CreateSplice; + creators[PrimType_Sin][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Size][REGIST_DT(kNumberTypeFloat16)] = CreateSize; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeFloat16)] = CreateSlice; + creators[PrimType_Square][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Sqrt][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeFloat16)] = CreateTile; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat16)] = CreateTriu; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat16)] = CreateTril; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeFloat16)] = CreateTranspose; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Unique][REGIST_DT(kNumberTypeFloat16)] = CreateUnique; +#endif +} + +void InitVSKernelA(KernelCreator **creators) { + creators[PrimType_Abs][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Abs][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticSelf; + creators[PrimType_Activation][REGIST_DT(kNumberTypeFloat32)] = CreateActivation; + creators[PrimType_Activation][REGIST_DT(kNumberTypeUInt32)] = CreateActivation; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_AddN][REGIST_DT(kNumberTypeFloat32)] = CreateAddN; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeInt32)] = CreateArgMinMax; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeInt32)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArgMinMax; + creators[PrimType_BiasAdd][REGIST_DT(kNumberTypeFloat32)] = CreateBiasAdd; + creators[PrimType_BatchNorm][REGIST_DT(kNumberTypeFloat32)] = CreateBatchNorm; + creators[PrimType_Ceil][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Cos][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Clip][REGIST_DT(kNumberTypeFloat)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeFloat32)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeInt)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeInt32)] = CreateClip; + creators[PrimType_Concat][REGIST_DT(kNumberTypeBool)] = CreateConcat; + creators[PrimType_Concat][REGIST_DT(kNumberTypeInt32)] = CreateConcat; + creators[PrimType_Concat][REGIST_DT(kNumberTypeFloat32)] = CreateConcat; + creators[PrimType_Crop][REGIST_DT(kNumberTypeInt32)] = CreateCrop; + creators[PrimType_Crop][REGIST_DT(kNumberTypeFloat32)] = CreateCrop; + creators[PrimType_CropAndResize][REGIST_DT(kNumberTypeFloat32)] = CreateCropAndResize; + creators[PrimType_DepthToSpace][REGIST_DT(kNumberTypeFloat32)] = CreateDepthToSpace; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Eltwise][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Equal][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + 
creators[PrimType_Equal][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_Erf][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_ExpFusion][REGIST_DT(kNumberTypeFloat32)] = CreateExp; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeInt8)] = CreateReshape; + creators[PrimType_Fill][REGIST_DT(kNumberTypeBool)] = CreateFill; + creators[PrimType_Fill][REGIST_DT(kNumberTypeInt32)] = CreateFill; + creators[PrimType_Fill][REGIST_DT(kNumberTypeFloat32)] = CreateFill; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_FlattenGrad][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Floor][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_FullConnection][REGIST_DT(kNumberTypeFloat32)] = CreateFullconnection; + creators[PrimType_FusedBatchNorm][REGIST_DT(kNumberTypeFloat32)] = CreateFusedBatchNorm; + creators[PrimType_Gather][REGIST_DT(kNumberTypeFloat32)] = CreateGather; + creators[PrimType_Gather][REGIST_DT(kNumberTypeInt32)] = CreateGather; + creators[PrimType_Gather][REGIST_DT(kNumberTypeBool)] = CreateGather; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeFloat32)] = CreateGatherD; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeInt32)] = CreateGatherD; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeBool)] = CreateGatherNd; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeInt32)] = CreateGatherNd; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeFloat32)] = CreateGatherNd; + creators[PrimType_Greater][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_Greater][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_GroupNormFusion][REGIST_DT(kNumberTypeFloat32)] = CreateGroupNorm; +} + +void InitVSKernelI(KernelCreator **creators) { + creators[PrimType_IsFinite][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LayerNormFusion][REGIST_DT(kNumberTypeFloat32)] = CreateLayerNorm; + creators[PrimType_Less][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_Less][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + 
creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeBool)] = CreateArithmetic;
+  creators[PrimType_Log][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf;
+  creators[PrimType_LogSoftmax][REGIST_DT(kNumberTypeFloat32)] = CreateLogSoftmax;
+  creators[PrimType_Log1p][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf;
+  creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf;
+  creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeBool)] = CreateArithmeticSelf;
+  creators[PrimType_LRN][REGIST_DT(kNumberTypeFloat32)] = CreateLocalResponseNorm;
+  creators[PrimType_Maximum][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic;
+  creators[PrimType_Maximum][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic;
+  creators[PrimType_MatMulFusion][REGIST_DT(kNumberTypeFloat32)] = CreateMatmul;
+  creators[PrimType_Mod][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic;
+  creators[PrimType_Mod][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic;
+  creators[PrimType_MulFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic;
+  creators[PrimType_MulFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic;
+  creators[PrimType_Minimum][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic;
+  creators[PrimType_Minimum][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic;
+  creators[PrimType_NLLLoss][REGIST_DT(kNumberTypeFloat32)] = CreateNLLLoss;
+  creators[PrimType_Neg][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf;
+  creators[PrimType_Neg][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticSelf;
+  creators[PrimType_NotEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare;
+  creators[PrimType_NotEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare;
+  creators[PrimType_NotEqual][REGIST_DT(kNumberTypeInt64)] = CreateArithmeticCompare;
+  creators[PrimType_NonZero][REGIST_DT(kNumberTypeBool)] = CreateNonZero;
+  creators[PrimType_NonMaxSuppression][REGIST_DT(kNumberTypeFloat32)] = CreateNonMaxSuppression;
+  creators[PrimType_PadFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePad;
+  creators[PrimType_PriorBox][REGIST_DT(kNumberTypeFloat32)] = CreatePriorBox;
+  creators[PrimType_PriorBox][REGIST_DT(kNumberTypeInt8)] = CreatePriorBox;
+  creators[PrimType_PowFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePow;
+  creators[PrimType_PReLUFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePRelu;
+}
+
+void InitVSKernelR(KernelCreator **creators) {
+  creators[PrimType_RaggedRange][REGIST_DT(kNumberTypeInt32)] = CreateRaggedRange;
+  creators[PrimType_RaggedRange][REGIST_DT(kNumberTypeFloat32)] = CreateRaggedRange;
+  creators[PrimType_Range][REGIST_DT(kNumberTypeFloat32)] = CreateRange;
+  creators[PrimType_Range][REGIST_DT(kNumberTypeInt32)] = CreateRange;
+  creators[PrimType_Range][REGIST_DT(kNumberTypeFloat16)] = CreateRange;
+  creators[PrimType_Rank][REGIST_DT(kNumberTypeFloat32)] = CreateRank;
+  creators[PrimType_Reshape][REGIST_DT(kNumberTypeInt32)] = CreateReshape;
+  creators[PrimType_Reshape][REGIST_DT(kNumberTypeFloat32)] = CreateReshape;
+  creators[PrimType_Reshape][REGIST_DT(kNumberTypeBool)] = CreateReshape;
+  creators[PrimType_RealDiv][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic;
+  creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeBool)] = CreateReduce;
+  creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeInt32)] = CreateReduce;
+  creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeFloat32)] = CreateReduce;
+  creators[PrimType_Reciprocal][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf;
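+  /* Note (illustrative): Reshape, ExpandDims, Flatten, Squeeze and Unsqueeze
+   * all map to CreateReshape in this table. They differ only in shape
+   * inference; at execution time each one reduces to copying the input
+   * buffer, conceptually
+   *
+   *   memcpy(out->data_, in->data_,
+   *          (size_t)NNACLGetElementNum(in) * DataTypeCSize(in->data_type_));
+   *
+   * so one kernel implementation serves the whole family. */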
creators[PrimType_ReverseV2][REGIST_DT(kNumberTypeInt32)] = CreateReverse; + creators[PrimType_ReverseV2][REGIST_DT(kNumberTypeFloat32)] = CreateReverse; + creators[PrimType_Round][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Rsqrt][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_ScaleFusion][REGIST_DT(kNumberTypeFloat32)] = CreateScale; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt32)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeBool)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeFloat32)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt8)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeUInt8)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt64)] = CreateShape; + creators[PrimType_Softmax][REGIST_DT(kNumberTypeFloat32)] = CreateSoftmax; + creators[PrimType_SquaredDifference][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Stack][REGIST_DT(kNumberTypeFloat32)] = CreateStack; + creators[PrimType_Stack][REGIST_DT(kNumberTypeInt32)] = CreateStack; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeFloat32)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt64)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt32)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt8)] = CreateStridedSlice; + creators[PrimType_Square][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Sqrt][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Sin][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Size][REGIST_DT(kNumberTypeInt32)] = CreateSize; + creators[PrimType_Size][REGIST_DT(kNumberTypeFloat32)] = CreateSize; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeInt32)] = CreateSlice; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeFloat32)] = CreateSlice; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_Splice][REGIST_DT(kNumberTypeFloat32)] = CreateSplice; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeInt32)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeFloat32)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeBool)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeUInt8)] = CreateTile; + creators[PrimType_Triu][REGIST_DT(kNumberTypeDouble)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat32)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt32)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt16)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt8)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt32)] = CreateTriu; + 
creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt16)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt8)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeBool)] = CreateTriu; + creators[PrimType_Tril][REGIST_DT(kNumberTypeDouble)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt16)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt8)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt16)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt8)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeBool)] = CreateTril; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeFloat32)] = CreateTranspose; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeInt32)] = CreateTranspose; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeInt64)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_Unique][REGIST_DT(kNumberTypeInt32)] = CreateUnique; + creators[PrimType_Unique][REGIST_DT(kNumberTypeFloat32)] = CreateUnique; +} + +void init_vs_kernels(KernelCreator **creators) { + InitVSKernelA(creators); + InitVSKernelI(creators); + InitVSKernelR(creators); + InitVSKernelF16(creators); +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.h new file mode 100644 index 00000000..1a8c9d53 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/init_vs_kernels.h @@ -0,0 +1,20 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef NNACL_KERNEL_INIT_VS_KERNELS_H_
+#define NNACL_KERNEL_INIT_VS_KERNELS_H_
+#include "nnacl_c/kernel.h"
+void init_vs_kernels(KernelCreator **creators);
+#endif  // NNACL_KERNEL_INIT_VS_KERNELS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.c
new file mode 100644
index 00000000..de70e04a
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.c
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/layer_norm.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/fp32/layer_norm_fp32.h"
+#ifdef ENABLE_FP16
+#include "nnacl_c/fp16/layer_norm_fp16.h"
+#endif
+
+int LayerNormRun(void *cdata, int task_id, float l, float r) {
+  LayerNormStruct *ln = (LayerNormStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(ln);
+  if (ln->data_type_ == kNumberTypeFloat16) {
+#ifdef ENABLE_FP16
+    return LayerNormFp16(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_,
+                         &ln->compute_, task_id, ln->base_.thread_nr_);
+#endif
+  }
+  return LayerNorm(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_,
+                   &ln->compute_, task_id, ln->base_.thread_nr_);
+}
+
+int LayerNormResize(KernelBase *self) {
+  LayerNormStruct *layer_norm = (LayerNormStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(layer_norm);
+  LayerNormComputeParam *compute = &layer_norm->compute_;
+
+  TensorC *input = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+
+  if (compute->begin_norm_axis_ < 0) {
+    compute->begin_norm_axis_ = compute->begin_norm_axis_ + (int)input->shape_size_;
+  }
+
+  if (compute->begin_params_axis_ < 0) {
+    compute->begin_params_axis_ = compute->begin_params_axis_ + (int)input->shape_size_;
+  }
+
+  compute->norm_outer_size_ = 1;
+  for (int i = 0; i < compute->begin_norm_axis_; ++i) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_outer_size_, input->shape_[i], NNACL_ERR);
+    compute->norm_outer_size_ *= input->shape_[i];
+  }
+
+  compute->norm_inner_size_ = 1;
+  for (size_t i = compute->begin_norm_axis_; i < input->shape_size_; ++i) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_inner_size_, input->shape_[i], NNACL_ERR);
+    compute->norm_inner_size_ *= input->shape_[i];
+  }
+
+  compute->params_outer_size_ = 1;
+  for (int i = 0; i < compute->begin_params_axis_; ++i) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_outer_size_, input->shape_[i], NNACL_ERR);
+    compute->params_outer_size_ *= input->shape_[i];
+  }
+
+  compute->params_inner_size_ = 1;
+  for (size_t i = compute->begin_params_axis_; i < input->shape_size_; ++i) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_inner_size_, input->shape_[i], NNACL_ERR);
+    compute->params_inner_size_ *= input->shape_[i];
+  }
+
+  int out_num = NNACLGetElementNum(self->out_[OUTPUT_INDEX]);
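+  /* Worked example for the sizes computed above: input shape [2, 4, 8] with
+   * begin_norm_axis_ = 2 and begin_params_axis_ = 2 gives
+   *   norm_outer_size_ = 2 * 4 = 8   (independent rows to normalize)
+   *   norm_inner_size_ = 8           (elements reduced per row)
+   * and the same split for the params sizes, so gamma/beta hold 8 values. */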
self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_LayerNormFusion), compute->norm_inner_size_, + compute->norm_inner_size_, out_num, self->thread_nr_); + self->thread_nr_ = NNACL_MIN(compute->norm_outer_size_, self->thread_nr_); + return NNACL_OK; +} + +int LayerNormCompute(KernelBase *self) { + LayerNormStruct *layer_norm = (LayerNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm); + + layer_norm->src_data_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->src_data_); + layer_norm->gamma_data_ = self->in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->gamma_data_); + layer_norm->beta_data_ = self->in_[THIRD_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->beta_data_); + layer_norm->dst_data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->dst_data_); + + if (layer_norm->base_.out_size_ == THREE_TENSOR) { + layer_norm->mean_data_ = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->mean_data_); + layer_norm->var_data_ = self->out_[Index2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->var_data_); + } else if (layer_norm->base_.out_size_ != ONE_TENSOR) { + return NNACL_LAYER_NORM_OUTPUT_NUM_INVALID; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, LayerNormRun, self, self->thread_nr_); +} + +KernelBase *CreateLayerNorm(OpParameter *param, int data_type) { + LayerNormStruct *layer_norm = (LayerNormStruct *)malloc(sizeof(LayerNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(layer_norm); + memset(layer_norm, 0, sizeof(LayerNormStruct)); + layer_norm->data_type_ = data_type; + + LayerNormParameter *layer_norm_param = (LayerNormParameter *)param; + layer_norm->compute_.epsilon_ = layer_norm_param->epsilon_; + layer_norm->compute_.elementwise_affine_ = layer_norm_param->elementwise_affine_; + layer_norm->compute_.begin_norm_axis_ = layer_norm_param->begin_norm_axis_; + layer_norm->compute_.begin_params_axis_ = layer_norm_param->begin_params_axis_; + + layer_norm->base_.Prepare = DefaultPrepare3In1Out; + layer_norm->base_.Release = DefaultRelease; + layer_norm->base_.Resize = LayerNormResize; + layer_norm->base_.Compute = LayerNormCompute; + return (KernelBase *)layer_norm; +} + +REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat16, CreateLayerNorm) +REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat32, CreateLayerNorm) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.h new file mode 100644 index 00000000..5b561a65 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/layer_norm.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_KERNEL_LAYER_NORM_H_
+#define NNACL_KERNEL_LAYER_NORM_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct LayerNormComputeParam {
+  float epsilon_;
+  bool elementwise_affine_;
+  int begin_norm_axis_;
+  int begin_params_axis_;
+  int norm_inner_size_;
+  int norm_outer_size_;
+  int params_inner_size_;
+  int params_outer_size_;
+} LayerNormComputeParam;
+
+typedef struct LayerNormStruct {
+  KernelBase base_;
+  LayerNormComputeParam compute_;
+  int data_type_;
+  void *src_data_;
+  void *dst_data_;
+  void *gamma_data_;
+  void *beta_data_;
+  void *mean_data_;
+  void *var_data_;
+} LayerNormStruct;
+
+KernelBase *CreateLayerNorm(OpParameter *param, int data_type);
+
+#endif  // NNACL_KERNEL_LAYER_NORM_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.c
new file mode 100644
index 00000000..a16b31c5
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.c
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/local_response_norm.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#include "nnacl_c/fp32/local_response_norm_fp32.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+int LocalResponseNormRun(void *cdata, int task_id, float l, float r) {
+  LocalResponseNormStruct *lrn = (LocalResponseNormStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(lrn);
+  LocalResponseNormParameter *param = (LocalResponseNormParameter *)lrn->base_.param_;
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+
+  TensorC *input = lrn->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+  TensorC *output = lrn->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+  NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_4D, NNACL_LOCAL_RESPONSE_NORM_SHAPE_INVALID);
+  NNACL_CHECK_FALSE(param->depth_radius_ <= 0, NNACL_LOCAL_RESPONSE_NORM_DEPTH_RADIUS_INVALID);
+
+  float *input_ptr = (float *)input->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  float *output_ptr = (float *)output->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+
+  int batch = NNACLGetBatch(input);
+  int height = NNACLGetHeight(input);
+  int width = NNACLGetWidth(input);
+  int channel = NNACLGetChannel(input);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(batch, width, NNACL_ERR);
+  int size_bw = batch * width;
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(size_bw, height, NNACL_ERR);
+  int outer_size = size_bw * height;
+  int stride = UP_DIV(outer_size, lrn->base_.thread_nr_);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, NNACL_ERR);
+  int start = stride * task_id;
+  int count = MSMIN(stride, outer_size - start);
+
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(start, channel, NNACL_ERR);
+  input_ptr += start * channel;
+  output_ptr += start * channel;
+
+  return LocalResponseNorm(input_ptr, count, channel, output_ptr, param);
+}
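+
+/* Worked example for the partition above: with outer_size = batch * width *
+ * height = 64 and thread_nr_ = 4, stride = UP_DIV(64, 4) = 16, so task_id 3
+ * handles rows 48..63 (count = MSMIN(16, 64 - 48) = 16), and its input and
+ * output pointers are advanced by start * channel floats. */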
+
+int LrnCompute(KernelBase *self) {
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, LocalResponseNormRun, self, self->thread_nr_);
+}
+
+KernelBase *CreateLocalResponseNorm(OpParameter *param, int data_type) {
+  LocalResponseNormStruct *lrn = (LocalResponseNormStruct *)malloc(sizeof(LocalResponseNormStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(lrn);
+  memset(lrn, 0, sizeof(LocalResponseNormStruct));
+
+  lrn->base_.Prepare = DefaultPrepare1In1Out;
+  lrn->base_.Release = DefaultRelease;
+  lrn->base_.Resize = DefaultResize;
+  lrn->base_.Compute = LrnCompute;
+  return (KernelBase *)lrn;
+}
+
+REG_KERNEL_CREATOR(PrimType_LRN, kNumberTypeFloat32, CreateLocalResponseNorm)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.h
new file mode 100644
index 00000000..0b3ebf73
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/local_response_norm.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_
+#define NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct LocalResponseNormStruct {
+  KernelBase base_;
+} LocalResponseNormStruct;
+
+KernelBase *CreateLocalResponseNorm(OpParameter *param, int data_type);
+
+#endif  // NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.c
new file mode 100644
index 00000000..3c311766
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.c
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "nnacl_c/kernel/log_softmax.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/log_softmax_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/log_softmax_fp16.h" +#endif + +int LogSoftmaxLastAxisRun(void *cdata, int task_id, float l, float r) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + TensorC *in = log_softmax->softmax_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *input_ptr = in->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = log_softmax->softmax_.base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + void *tmp_ptr = log_softmax->softmax_.sum_data_; + NNACL_CHECK_NULL_RETURN_ERR(tmp_ptr); + + int unit = UP_DIV(log_softmax->softmax_.out_plane_size_, log_softmax->softmax_.base_.thread_nr_); + int begin = task_id * unit; + int end = MSMIN(begin + unit, log_softmax->softmax_.out_plane_size_); + int channel = in->shape_[log_softmax->softmax_.axis_]; + int offset = begin * channel; + +#ifdef ENABLE_FP16 + if (log_softmax->softmax_.data_type_ == kNumberTypeFloat16) { + LogSoftmaxLastAxisFp16((const float16_t *)input_ptr + offset, (float16_t *)output_ptr + offset, + (float16_t *)tmp_ptr + offset, end - begin, channel); + return NNACL_OK; + } +#endif + LogSoftmaxLastAxis((const float *)input_ptr + offset, (float *)output_ptr + offset, (float *)tmp_ptr + offset, + end - begin, channel); + return NNACL_OK; +} + +int LogSoftmaxResize(struct KernelBase *self) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + int ret = InitSoftmaxParam(&log_softmax->softmax_); + if (ret != NNACL_OK) { + return ret; + } + + if (log_softmax->softmax_.in_plane_size_ == 1 && log_softmax->softmax_.sum_data_ == NULL) { + TensorC *in = log_softmax->softmax_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + SoftmaxStruct *softmax = &log_softmax->softmax_; + + int sum_data_size = softmax->in_plane_size_ * softmax->out_plane_size_ * in->shape_[softmax->axis_]; + softmax->sum_data_ = self->env_->Alloc(self->env_->allocator_, sum_data_size * DataTypeCSize(softmax->data_type_)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(softmax->sum_data_); + } + return NNACL_OK; +} + +int LogSoftmaxCompute(struct KernelBase *self) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + if (log_softmax->softmax_.in_plane_size_ == 1) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, LogSoftmaxLastAxisRun, self, self->thread_nr_); + } + + TensorC *in = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *input_ptr = in->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + NNACL_CHECK_NULL_RETURN_ERR(log_softmax->softmax_.sum_data_); + +#ifdef ENABLE_FP16 + if (log_softmax->softmax_.data_type_ == kNumberTypeFloat16) { + LogSoftmaxFp16((const float16_t *)input_ptr, (float16_t *)output_ptr, (float16_t *)log_softmax->softmax_.sum_data_, + in->shape_, in->shape_size_, log_softmax->softmax_.axis_); + return NNACL_OK; + } +#endif + LogSoftmax((const float *)input_ptr, (float *)output_ptr, (float *)log_softmax->softmax_.sum_data_, in->shape_, + in->shape_size_, log_softmax->softmax_.axis_); + return NNACL_OK; +} + +KernelBase *CreateLogSoftmax(OpParameter *param, int data_type) { + LogSoftmaxStruct *log_softmax 
= (LogSoftmaxStruct *)malloc(sizeof(LogSoftmaxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(log_softmax); + memset(log_softmax, 0, sizeof(LogSoftmaxStruct)); + + log_softmax->softmax_.sum_data_ = NULL; + log_softmax->softmax_.data_type_ = data_type; + log_softmax->softmax_.base_.Prepare = DefaultPrepare1In1Out; + log_softmax->softmax_.base_.Release = SoftmaxRelease; + log_softmax->softmax_.base_.Resize = LogSoftmaxResize; + log_softmax->softmax_.base_.Compute = LogSoftmaxCompute; + return (KernelBase *)log_softmax; +} + +REG_KERNEL_CREATOR(PrimType_LogSoftmax, kNumberTypeFloat32, CreateLogSoftmax) +REG_KERNEL_CREATOR(PrimType_LogSoftmax, kNumberTypeFloat16, CreateLogSoftmax) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.h new file mode 100644 index 00000000..65abf3c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/log_softmax.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_LOG_SOFTMAX_H_ +#define NNACL_KERNEL_LOG_SOFTMAX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/kernel/softmax.h" + +typedef struct LogSoftmaxStruct { + SoftmaxStruct softmax_; +} LogSoftmaxStruct; + +KernelBase *CreateLogSoftmax(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_LOG_SOFTMAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.c new file mode 100644 index 00000000..8644c41a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.c @@ -0,0 +1,176 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/matmul.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/kernel/matmul_create.h" + +void MatmulInitShapeA(MatmulStruct *matmul) { + int *a_shape = matmul->base_.in_[kInputIndex]->shape_; + size_t a_shape_size = matmul->base_.in_[kInputIndex]->shape_size_; + int batch = 1; + NNACL_CHECK_TRUE_RET_VOID(a_shape_size >= C2NUM); + for (size_t i = 0; i < a_shape_size - C2NUM; ++i) { + batch *= a_shape[i]; + } + matmul->a_batch_ = batch; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.row_ = param->a_transpose_ ? 
a_shape[a_shape_size - 1] : a_shape[a_shape_size - C2NUM];
+  matmul->compute_.deep_ = param->a_transpose_ ? a_shape[a_shape_size - C2NUM] : a_shape[a_shape_size - 1];
+}
+
+void MatmulInitShapeB(MatmulStruct *matmul) {
+  int *b_shape = matmul->base_.in_[kWeightIndex]->shape_;
+  size_t b_shape_size = matmul->base_.in_[kWeightIndex]->shape_size_;
+  int batch = 1;
+  NNACL_CHECK_TRUE_RET_VOID(b_shape_size >= C2NUM);
+  for (size_t i = 0; i < b_shape_size - C2NUM; ++i) {
+    batch *= b_shape[i];
+  }
+  matmul->b_batch_ = batch;
+  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
+  matmul->compute_.col_ = param->b_transpose_ ? b_shape[b_shape_size - C2NUM] : b_shape[b_shape_size - 1];
+  matmul->compute_.deep_ = param->b_transpose_ ? b_shape[b_shape_size - 1] : b_shape[b_shape_size - C2NUM];
+}
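+
+/* Worked example for the broadcast bookkeeping below: A with shape
+ * [2, 1, 3, 4] times B with shape [1, 5, 4, 6] yields batch_ = 2 * 5 = 10
+ * output matrices. For output index i = 7 (decomposed as 7 = 1 * 5 + 2) the
+ * loops produce
+ *   a_offset_[7] = 1  (A is broadcast along the size-5 dim, so only the
+ *                      leading index 1 contributes)
+ *   b_offset_[7] = 2  (B is broadcast along the size-2 dim, so only the
+ *                      trailing index 2 contributes)
+ * i.e. every output matrix reads the unique A/B matrix its indices map to. */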
+
+int MatmulInitBroadcastParams(MatmulStruct *matmul) {
+  TensorC *a = matmul->base_.in_[FIRST_INPUT];
+  TensorC *b = matmul->base_.in_[SECOND_INPUT];
+
+  int max_dim_size = (int)NNACL_MAX(a->shape_size_, b->shape_size_);
+  max_dim_size = NNACL_MAX(max_dim_size, COMM_SHAPE_SIZE);
+
+  int a_shape[MAX_SHAPE_SIZE] = {0};
+  int index = max_dim_size - 1;
+  for (int i = (int)a->shape_size_ - 1; i >= 0; i--) {
+    a_shape[index--] = a->shape_[i];
+  }
+  while (index >= 0) {
+    a_shape[index--] = 1;
+  }
+
+  int b_shape[MAX_SHAPE_SIZE] = {0};
+  index = max_dim_size - 1;
+  for (int i = (int)b->shape_size_ - 1; i >= 0; i--) {
+    b_shape[index--] = b->shape_[i];
+  }
+  while (index >= 0) {
+    b_shape[index--] = 1;
+  }
+
+  int batch_sizes[MAX_SHAPE_SIZE] = {0};
+  int a_batch_sizes[MAX_SHAPE_SIZE] = {0};
+  int b_batch_sizes[MAX_SHAPE_SIZE] = {0};
+  for (int i = max_dim_size - Num3; i >= 0; --i) {
+    if (max_dim_size - Num3 == i) {
+      batch_sizes[i] = NNACL_MAX(a_shape[i], b_shape[i]);
+      a_batch_sizes[i] = a_shape[i];
+      b_batch_sizes[i] = b_shape[i];
+    } else {
+      batch_sizes[i] = batch_sizes[i + 1] * NNACL_MAX(a_shape[i], b_shape[i]);
+      a_batch_sizes[i] = a_batch_sizes[i + 1] * a_shape[i];
+      b_batch_sizes[i] = b_batch_sizes[i + 1] * b_shape[i];
+    }
+  }
+
+  int out_batch = 1;
+  for (int i = 0; i < max_dim_size - Num2; ++i) {
+    int max_v = NNACL_MAX(a_shape[i], b_shape[i]);
+    int min_v = NNACL_MIN(a_shape[i], b_shape[i]) > 0 ? NNACL_MIN(a_shape[i], b_shape[i]) : 1;
+    out_batch *= max_v;
+    if ((max_v != min_v) && ((max_v % min_v) != 0)) {
+      return NNACL_ERR;
+    }
+  }
+  matmul->batch_ = out_batch;
+
+  MatmulBaseFreeBatchOffset(matmul);
+  int ret = MatmulBaseMallocBatchOffset(matmul);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  for (int i = 0; i < matmul->batch_; ++i) {
+    int delta = i;
+    int a_offset = 0;
+    int b_offset = 0;
+    for (int j = 0; j < max_dim_size - Num2; ++j) {
+      if (j > 0) {
+        delta = delta % batch_sizes[j];
+      }
+      if (j >= (MAX_SHAPE_SIZE - 1)) {
+        return NNACL_ERR;
+      }
+      if (j < (max_dim_size - Num3)) {
+        a_offset +=
+          (delta / batch_sizes[j + 1] * a_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])) * a_batch_sizes[j + 1];
+        b_offset +=
+          (delta / batch_sizes[j + 1] * b_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])) * b_batch_sizes[j + 1];
+      } else {
+        a_offset += (delta * a_shape[j] / NNACL_MAX(a_shape[j], b_shape[j]));
+        b_offset += (delta * b_shape[j] / NNACL_MAX(a_shape[j], b_shape[j]));
+      }
+    }
+    matmul->a_offset_[i] = a_offset;
+    matmul->b_offset_[i] = b_offset;
+  }
+  return NNACL_OK;
+}
+
+int MatmulPrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR);
+
+  MatmulStruct *matmul = (MatmulStruct *)self;
+  if (matmul->a_const_ || matmul->infer_shape_) {
+    MatmulInitShapeA(matmul);
+  }
+
+  if (matmul->b_const_ || matmul->infer_shape_) {
+    MatmulInitShapeB(matmul);
+  }
+
+  return MatmulBasePrepare(self);
+}
+
+int MatmulResize(KernelBase *self) {
+  MatmulStruct *matmul = (MatmulStruct *)self;
+  MatmulInitShapeA(matmul);
+  MatmulInitShapeB(matmul);
+
+  int ret = MatmulInitBroadcastParams(matmul);
+  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
+  return MatmulBaseResize(self);
+}
+
+int MatmulRelease(KernelBase *self) {
+  MatmulBaseFreeBatchOffset((MatmulStruct *)self);
+  return MatmulBaseRelease(self);
+}
+
+KernelBase *CreateMatmul(OpParameter *param, int data_type) {
+  KernelBase *kernel = NULL;
+  if (data_type == kNumberTypeFloat32) {
+    kernel = CreateMatmulKernel();
+    NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel);
+    kernel->Prepare = MatmulPrepare;
+    kernel->Resize = MatmulResize;
+    kernel->Release = MatmulRelease;
+  }
+  return kernel;
+}
+
+REG_KERNEL_CREATOR(PrimType_MatMulFusion, kNumberTypeFloat32, CreateMatmul);
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.h
new file mode 100644
index 00000000..db1d3840
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul.h
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef NNACL_KERNEL_MATMUL_H_ +#define NNACL_KERNEL_MATMUL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmul(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.c new file mode 100644 index 00000000..43bb65d8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM32 +#include "nnacl_c/kernel/matmul_arm32.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" + +void MatmulARM32InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col4MajorParallel : RowMajor2Row4MajorParallel; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C4NUM; + matmul->compute_.col_min_unit_ = C4NUM; +} + +int MatmulARM32ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulARM32ParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulARM32() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulARM32InitGlobalVariable; + matmul->parallel_run_by_batch_ = MatmulARM32ParallelRunByBatch; + matmul->parallel_run_by_oc_ = MatmulARM32ParallelRunByOC; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.h new file mode 100644 index 00000000..a3bc34ac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm32.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_MATMUL_ARM32_H_ +#define NNACL_KERNEL_MATMUL_ARM32_H_ + +#ifdef ENABLE_ARM32 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulARM32(); + +#endif +#endif // NNACL_KERNEL_MATMUL_ARM32_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.c new file mode 100644 index 00000000..c61068a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.c @@ -0,0 +1,214 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/kernel/matmul_arm64.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/pack_fp32_opt.h" + +typedef struct MatrixAPack { + int64_t points_[MAX_THREAD_NUM]; + int64_t unit_num_; + int thread_; + int deep_; + int row_; + int col_; + MatrixInfo *matrix_a_; + float *src_ptr_; + bool a_transpose_; +} MatrixAPack; + +int MatmulARM64PackMatrixAImplOptPack(void *cdata, int task_id, float l, float r) { + MatrixAPack *pack = (MatrixAPack *)cdata; + int64_t start = pack->points_[task_id]; + int64_t end = pack->unit_num_; + if (task_id < pack->thread_ - 1) { + end = pack->points_[task_id + 1]; + } + + if (pack->a_transpose_) { + RowMajor2Row12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->deep_, pack->row_, start, end); + } else { + RowMajor2Col12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->row_, pack->deep_, start, end); + } + return NNACL_OK; +} + +int MatmulARM64PackMatrixAImplOpt(MatmulStruct *matmul) { + int64_t kPackAMinUnitNum = 1 << 13; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + float *src_ptr = matmul->matrix_a_.origin_ptr_ != NULL ? 
matmul->matrix_a_.origin_ptr_
+                                                         : (float *)(matmul->base_.in_[FIRST_INPUT]->data_);
+  NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR);
+  NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR);
+
+  MatrixAPack pack;
+  pack.src_ptr_ = src_ptr;
+  pack.matrix_a_ = &matmul->matrix_a_;
+  pack.deep_ = matmul->compute_.deep_;
+  pack.col_ = matmul->compute_.col_;
+  pack.row_ = matmul->compute_.row_;
+  pack.a_transpose_ = param->a_transpose_;
+  pack.unit_num_ =
+      matmul->a_batch_ * UP_DIV(matmul->compute_.row_, C12NUM) * matmul->compute_.deep_;
+  pack.thread_ = MSMIN(matmul->base_.thread_nr_, UP_DIV(pack.unit_num_, kPackAMinUnitNum));
+  if (pack.thread_ < 1) {
+    pack.thread_ = 1;
+  }
+  int64_t block_size = pack.unit_num_ / pack.thread_;
+  int64_t remain_size = pack.unit_num_ - block_size * pack.thread_;
+  int64_t start = 0;
+  size_t count = 0;
+  while (start < pack.unit_num_) {
+    pack.points_[count++] = start;
+    start += block_size;
+    if (remain_size > 0) {
+      ++start;
+      --remain_size;
+    }
+  }
+  pack.thread_ = count;
+
+  if (pack.thread_ == 1) {
+    return MatmulARM64PackMatrixAImplOptPack(&pack, 0, 0, 1);
+  }
+  return matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulARM64PackMatrixAImplOptPack, &pack,
+                                            pack.thread_);
+}
+
+bool MatmulARM64CheckThreadCuttingByRow(MatmulStruct *matmul) {
+  if (matmul->b_batch_ != C1NUM) {
+    return false;
+  }
+  if (matmul->batch_ >= matmul->base_.thread_nr_ || matmul->compute_.col_ == 1) {
+    matmul->compute_.row_min_unit_ = C4NUM;
+    return true;
+  }
+  return false;
+}
+void MatmulARM64InitGlobalVariable(MatmulStruct *matmul) {
+  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
+  matmul->pack_opt_ = true;
+  matmul->compute_.row_tile_ = C12NUM;
+  matmul->compute_.col_tile_ = C8NUM;
+  matmul->compute_.col_min_unit_ = C8NUM;
+  matmul->matrix_a_.need_pack_ = true;
+  matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_;
+  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
+  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel;
+}
+
+int MatmulARM64ParallelRunByBatch(MatmulStruct *matmul, int task_id) {
+  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
+  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
+  MatmulComputeParam *compute = &matmul->compute_;
+  ActType act = param->act_type_;
+
+  int start_batch = task_id * compute->batch_stride_;
+  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);
+  int func_flag = 0;
+  if (compute->row_ == 1) {
+    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
+  }
+
+  for (int index = start_batch; index < end_batch; ++index) {
+    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
+    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
+    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;
+    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulARM64ParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int start_row = matmul->split_points_[task_id]; + int end_row = matmul->compute_.row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + GemmIsNotPackByRow(matmul->matrix_a_.pack_ptr_, matmul->matrix_b_.pack_ptr_, matmul->output_data_, + matmul->matrix_c_.pack_ptr_, start_row, end_row, matmul->compute_.deep_, param->act_type_); + return NNACL_OK; +} + +int MatmulARM64ParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulARM64() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kMatmulFp32Arm64Cpu; + matmul->check_thread_cutting_by_row_ = MatmulARM64CheckThreadCuttingByRow; + matmul->init_global_varibale_ = MatmulARM64InitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulARM64ParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulARM64ParallelRunByRow; + matmul->parallel_run_by_batch_ = MatmulARM64ParallelRunByBatch; + matmul->pack_matrix_a_impl_opt_ = MatmulARM64PackMatrixAImplOpt; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.h new file mode 100644 index 00000000..35c938a7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_arm64.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_ARM64_H_ +#define NNACL_KERNEL_MATMUL_ARM64_H_ + +#ifdef ENABLE_ARM64 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulARM64(); + +#endif +#endif // NNACL_KERNEL_MATMUL_ARM64_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.c new file mode 100644 index 00000000..9ccae7ec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.c @@ -0,0 +1,169 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef ENABLE_AVX +#include "nnacl_c/kernel/matmul_avx.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +void MatmulAVXInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.row_tile_ = C1NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C32NUM; + matmul->out_need_aligned_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_.need_pack_ = param->a_transpose_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col32MajorParallel : RowMajor2Row32MajorParallel; +} + +int MatmulAVXParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (matmul->compute_.row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + ActType act = param->act_type_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_); + } else if (func_flag == C1NUM) { + MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulAVXParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + + int start_row = matmul->split_points_[task_id]; + int end_row = compute->row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + const float *input = matmul->matrix_a_.pack_ptr_ + start_row * compute->deep_; + float *output = matmul->output_data_ + start_row * compute->col_align_; + if (compute->col_ == 1) { + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, compute->deep_, + param->act_type_); + } else { + MatMulAvxFp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + compute->deep_, compute->col_align_, compute->col_align_, row_num); + } + return NNACL_OK; +} + +int MatmulAVXParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || 
task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulAvxFp32(a, b, c, bias, param->act_type_, compute->deep_, compute_oc, compute->col_align_, compute->row_); + } else if (func_flag == C1NUM) { + MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +bool MatmulAVXCheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) { + return false; + } + if (matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C4NUM; + return true; + } + if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + return false; + } + matmul->compute_.row_min_unit_ = C3NUM; + if (matmul->compute_.col_step_ < C16NUM) { + matmul->compute_.row_min_unit_ = C8NUM; + } else if (matmul->compute_.col_step_ < C24NUM) { + matmul->compute_.row_min_unit_ = C6NUM; + } else if (matmul->compute_.col_step_ < C32NUM) { + matmul->compute_.row_min_unit_ = C4NUM; + } + return MSMIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) > + MSMIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_); +} + +KernelBase *CreateMatmulAVX() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulAVXInitGlobalVariable; + matmul->parallel_run_by_batch_ = MatmulAVXParallelRunByBatch; + matmul->parallel_run_by_row_ = MatmulAVXParallelRunByRow; + matmul->parallel_run_by_oc_ = MatmulAVXParallelRunByOC; + matmul->check_thread_cutting_by_row_ = MatmulAVXCheckThreadCuttingByRow; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.h new file mode 100644 index 00000000..bb722473 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_MATMUL_AVX_H_
+#define NNACL_KERNEL_MATMUL_AVX_H_
+
+#ifdef ENABLE_AVX
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+KernelBase *CreateMatmulAVX();
+
+#endif
+#endif  // NNACL_KERNEL_MATMUL_AVX_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.c
new file mode 100644
index 00000000..75e80eaf
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.c
@@ -0,0 +1,708 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef ENABLE_AVX512
+#include "nnacl_c/kernel/matmul_avx512.h"
+#include "nnacl_c/kernel/matmul_base.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_avx512_fp32.h"
+#include "nnacl_c/fp32/matmul_avx512_mask_fp32.h"
+
+#define MIN_CALC_COST 24576 /* 1 x 6 x 64 x 64 */
+
+void MatmulAVX512BatchRowThreadCut(MatmulStruct *matmul) {
+  // BatchCut
+  matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_);
+
+  // RowCut
+  int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_);
+  int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_;
+
+  matmul->row_split_points_size_ = 0;
+  int row_split_point = 0;
+  while (row_split_point < matmul->compute_.row_) {
+    matmul->row_split_points_[matmul->row_split_points_size_++] = row_split_point;
+    row_split_point += row_step;
+    if (row_remaining > 0) {
+      ++row_split_point;
+      --row_remaining;
+    }
+  }
+  matmul->row_split_points_[matmul->row_split_points_size_] = matmul->compute_.row_;
+  if (matmul->compute_.batch_stride_ == 0) {
+    matmul->base_.thread_nr_ = matmul->row_split_points_size_;
+  }
+}
+
+void MatmulAVX512BatchColThreadCut(MatmulStruct *matmul) {
+  // BatchCut
+  matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_);
+
+  // ColCut
+  int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_);
+  int thread_num_tmp = NNACL_MIN(matmul->base_.thread_nr_, total_col_unit);
+  int block_col_unit = UP_DIV(total_col_unit, thread_num_tmp);
+  int split_point = 0;
+  matmul->col_split_points_size_ = 0;
+  while (split_point < total_col_unit) {
+    matmul->col_split_points_[matmul->col_split_points_size_++] = split_point * matmul->compute_.col_min_unit_;
+    split_point += block_col_unit;
+  }
+  if (matmul->compute_.batch_stride_ == 0) {
+    
matmul->base_.thread_nr_ = matmul->col_split_points_size_; + } +} + +void MatmulAVX512BatchColRowSliceThreadCut(MatmulStruct *matmul) { + // BatchCut + matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_); + + int row_s = 0; + int row_e = matmul->compute_.row_; + int col_s = 0; + int col_e = matmul->compute_.col_; + + // ColCut + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + matmul->compute_.block_col_unit_ = DOWN_DIV(total_col_unit, matmul->base_.thread_nr_); + matmul->col_split_points_size_ = 0; + matmul->col_split_points_[matmul->col_split_points_size_++] = 0; + if (matmul->compute_.block_col_unit_ > 0) { + int col_split_point = 0; + for (int i = 0; i < matmul->base_.thread_nr_; i++) { + MatmulSlice matmul_slice; + matmul_slice.row_s_ = row_s; + matmul_slice.row_e_ = row_e; + matmul_slice.col_s_ = col_split_point * matmul->compute_.col_min_unit_; + col_split_point += matmul->compute_.block_col_unit_; + col_s = NNACL_MIN(col_split_point * matmul->compute_.col_min_unit_, matmul->compute_.col_step_); + matmul_slice.col_e_ = col_s; + matmul->matmul_slice_set_[i][matmul->matmul_slice_count_[i]++] = matmul_slice; + } + } + if (col_e - col_s <= 0) { + return; + } + + // RowColCut + int row_thread = 0; + int less_col_align = UP_ROUND(col_e - col_s, C16NUM); + bool use_colrowcut_flag = ((less_col_align / C64NUM) * C64NUM) == less_col_align; + bool use_rowcut_flag = matmul->compute_.row_ >= C6NUM * matmul->base_.thread_nr_ || col_e - col_s <= C64NUM; + if (use_rowcut_flag && !use_colrowcut_flag) { + int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_; + int row_split_point = 0; + + for (row_thread = 0; row_thread < matmul->base_.thread_nr_ && row_split_point < matmul->compute_.row_; + row_thread++) { + MatmulSlice matmul_slice; + matmul_slice.row_s_ = row_split_point; + + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + + matmul_slice.row_e_ = row_split_point; + matmul_slice.col_s_ = col_s; + matmul_slice.col_e_ = col_e; + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + } + } else { + int col_num = UP_DIV(col_e - col_s, C64NUM); + int row_num = NNACL_MIN(UP_DIV(matmul->base_.thread_nr_, col_num), (row_e - row_s)); + int tile_remaining = MSMAX(col_num * row_num - matmul->base_.thread_nr_, 0); + + NNACL_CHECK_ZERO_RETURN(row_num); + int row_step = (row_e - row_s) / row_num; + int row_remaining_tmp = (row_e - row_s) - row_step * row_num; + + int row_step_cut2 = (row_num == 1) ? 
row_step : (row_e - row_s) / (row_num - 1); + int row_remaining_cut2_tmp = (row_e - row_s) - row_step_cut2 * (row_num - 1); + + MatmulSlice matmul_slice; + for (int c = 0; c < col_num; c++) { + matmul_slice.col_s_ = col_s + c * C64NUM; + matmul_slice.col_e_ = NNACL_MIN(col_s + (c + 1) * C64NUM, matmul->compute_.col_); + int row_split_point = 0; + int row_remaining = row_remaining_tmp; + int row_remaining_cut2 = row_remaining_cut2_tmp; + if (c < col_num - tile_remaining) { + for (int r = 0; r < row_num; r++) { + matmul_slice.row_s_ = row_split_point; + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_); + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + row_thread++; + } + } else { + for (int r = 0; r < row_num - 1; r++) { + matmul_slice.row_s_ = row_split_point; + row_split_point += row_step_cut2; + if (row_remaining_cut2 > 0) { + ++row_split_point; + --row_remaining_cut2; + } + matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_); + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + row_thread++; + } + } + } + } + if ((matmul->compute_.batch_stride_ == 0) && (matmul->compute_.block_col_unit_ == 0)) { + matmul->base_.thread_nr_ = row_thread; + } +} + +void MatmulAVX512GetThreadCuttingPolicy(MatmulStruct *matmul) { + size_t total_cost = (size_t)(matmul->batch_) * (size_t)(matmul->compute_.row_) * (size_t)(matmul->compute_.col_) * + (size_t)(matmul->compute_.deep_); + + // Thread Update + matmul->base_.thread_nr_ = MSMAX(NNACL_MIN((int)(total_cost / MIN_CALC_COST), matmul->base_.thread_nr_), C1NUM); + + if (matmul->compute_.deep_ < C128NUM) { + return MatmulBaseGetThreadCuttingPolicy(matmul); + } + + for (int i = 0; i < SPLIT_COUNT; i++) { + matmul->matmul_slice_count_[i] = 0; + } + if (matmul->compute_.col_ == 1 && !matmul->a_const_) { + MatmulAVX512BatchRowThreadCut(matmul); + if (matmul->compute_.deep_ == 1) { + matmul->gemm_not_pack_fun_ = GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize; + } + matmul->parallel_run_ = matmul->parallel_run_by_gepdot_; + } else if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + MatmulAVX512BatchColThreadCut(matmul); + if (matmul->compute_.deep_ == 1) { + matmul->parallel_run_ = matmul->parallel_run_by_row1_deep1_gepdot_; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + matmul->gemm_not_pack_fun_ = Row1Deep1GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = Row1Deep1NoBiasGemmIsNotPack; + } + return; + } + matmul->parallel_run_ = matmul->parallel_run_by_gepm_; + } else { + MatmulAVX512BatchColRowSliceThreadCut(matmul); + matmul->parallel_run_ = matmul->parallel_run_by_batch_col_row_gemm_; + } + return; +} + +bool MatmulAVX512CheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) { + return false; + } + if (matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C8NUM; + return true; + } + if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + return false; + } + matmul->compute_.row_min_unit_ = C6NUM; + if (matmul->compute_.col_step_ < C48NUM) { + matmul->compute_.row_min_unit_ = C12NUM; + } else if (matmul->compute_.col_step_ < C64NUM) { + matmul->compute_.row_min_unit_ = C8NUM; + } + return 
NNACL_MIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) > + NNACL_MIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_); +} +void MatmulAVX512InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col64MajorParallel : RowMajor2Row64MajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_; + matmul->matrix_b_.need_pack_ = true; + matmul->compute_.row_tile_ = C1NUM; + matmul->compute_.col_tile_ = C16NUM; + matmul->compute_.col_min_unit_ = C64NUM; + + if (matmul->compute_.row_ == 1) { + if (!matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + matmul->out_need_aligned_ = true; + } + } else if (matmul->compute_.col_ == 1) { + matmul->out_need_aligned_ = true; + } else { + matmul->out_need_aligned_ = false; + } + + if (matmul->compute_.deep_ >= C128NUM) { + matmul->out_need_aligned_ = false; + } +} +int MatmulAVX512InitParameter(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + if (compute->deep_ < C128NUM) { + return MatmulBaseInitParameter(matmul); + } + + matmul->init_global_varibale_(matmul); + if (compute->col_ == 1 && !matmul->a_const_) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1; + matmul->matrix_b_.need_pack_ = false; + matmul->pack_opt_ = false; + } else if (compute->row_ == 1 && !matmul->b_const_ && compute->col_ <= C128NUM) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = false; + matmul->matrix_b_.need_pack_ = param->b_transpose_; + matmul->pack_opt_ = false; + } + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR); + int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_; + int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_; + if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) || + (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) { + return NNACL_ERR; + } + matmul->matrix_a_.pack_size_ = a_pack_size; + matmul->matrix_b_.pack_size_ = b_pack_size; + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0)); + compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR); + compute->row_num_ = matmul->a_batch_ * compute->row_; + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int start_row = matmul->split_points_[task_id]; + int end_row = matmul->compute_.row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + const float *input = matmul->matrix_a_.pack_ptr_ + start_row * matmul->compute_.deep_; + float *output = matmul->output_data_ + start_row * matmul->compute_.col_step_; + if (matmul->compute_.col_ == 1) { + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, matmul->compute_.deep_, + param->act_type_); + } else { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + matmul->compute_.deep_, matmul->compute_.col_align_, matmul->compute_.col_align_, row_num); + } else { + MatMulMaskAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + matmul->compute_.deep_, matmul->compute_.col_, matmul->compute_.col_, row_num); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if 
(compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (matmul->compute_.row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_, compute->row_); + } else { + MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_, compute->row_); + } + } else if (func_flag == C1NUM) { + if (matmul->out_need_aligned_) { + MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_); + } else { + MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_); + } + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_); + } else { + MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_, compute->row_); + } + } else if (func_flag == C1NUM) { + if (matmul->out_need_aligned_) { + MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_); + } else { + MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_); + } + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByGEPM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matrix_col; + if (task_id < (col_split_points_size - 1)) { + end_oc = matmul->col_split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col); + } + } + return NNACL_OK; +} +int MatmulAVX512ParallelRunByGEMM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_row = matmul->compute_.row_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = start_batch + matmul->compute_.batch_stride_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matmul->col_split_points_[task_id + 1]; + int compute_oc = end_oc - start_oc; + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + if (compute_oc > 0) { + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, matrix_row); + } + } + } + + // by RowCut + int start_oc = matmul->col_split_points_[col_split_points_size]; + int end_oc = matrix_col; + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int row_split_points_size = matmul->row_split_points_size_; + if (task_id >= row_split_points_size) { + return NNACL_OK; + } + int start_row = matmul->row_split_points_[task_id]; + int end_row = matmul->row_split_points_[task_id + 1]; + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num); + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByGEPDOT(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + + // by BatchCut + int start_batch = task_id * compute->batch_stride_; + int end_batch = start_batch + compute->batch_stride_; + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + int a_stride = compute->row_ * compute->deep_; + int b_stride = compute->deep_ * compute->col_; + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, compute->row_, compute->deep_, param->act_type_); + } + + // by RowCut + int split_points_size = matmul->row_split_points_size_; + if (task_id >= split_points_size) { + return NNACL_OK; + } + for (int index = matmul->base_.thread_nr_ * compute->batch_stride_; index < matmul->batch_; ++index) { + int start_row = matmul->row_split_points_[task_id]; + int end_row = matmul->row_split_points_[task_id + 1]; + int row_num = end_row - start_row; + if (row_num <= 0) { + continue; + } + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride + start_row * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_ + start_row * compute->col_step_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, row_num, compute->deep_, param->act_type_); + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByRow1Deep1GEPDOT(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_; + matmul->gemm_not_pack_fun_(a, b, c, bias, matrix_col, matrix_deep, param->act_type_); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matrix_col; + if (task_id < (col_split_points_size - 1)) { + end_oc = matmul->col_split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + matmul->gemm_not_pack_fun_(a, b, c, bias, compute_oc, matrix_deep, param->act_type_); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByBatchColRowGEMM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_row = matmul->compute_.row_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = start_batch + matmul->compute_.batch_stride_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row); + } + + MatmulSlice *matmul_slices = matmul->matmul_slice_set_[task_id]; + int slice_count = matmul->matmul_slice_count_[task_id]; + for (int s = 0; s < slice_count; s++) { + MatmulSlice matmul_slice = matmul_slices[s]; + + int start_oc = matmul_slice.col_s_; + int end_oc = matmul_slice.col_e_; + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int start_row = matmul_slice.row_s_; + int end_row = matmul_slice.row_e_; + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulAVX512() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->check_thread_cutting_by_row_ = MatmulAVX512CheckThreadCuttingByRow; + matmul->get_thread_cutting_policy_ = MatmulAVX512GetThreadCuttingPolicy; + matmul->init_parameter_ = MatmulAVX512InitParameter; + matmul->init_global_varibale_ = MatmulAVX512InitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulAVX512ParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulAVX512ParallelRunByRow; + matmul->parallel_run_by_batch_ = MatmulAVX512ParallelRunByBatch; + matmul->parallel_run_by_gemm_ = MatmulAVX512ParallelRunByGEMM; + matmul->parallel_run_by_gepm_ = MatmulAVX512ParallelRunByGEPM; + matmul->parallel_run_by_gepdot_ = MatmulAVX512ParallelRunByGEPDOT; + matmul->parallel_run_by_batch_col_row_gemm_ = MatmulAVX512ParallelRunByBatchColRowGEMM; + matmul->parallel_run_by_row1_deep1_gepdot_ = MatmulAVX512ParallelRunByRow1Deep1GEPDOT; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.h new file mode 100644 index 00000000..4233286c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_avx512.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_AVX512_H_ +#define NNACL_KERNEL_MATMUL_AVX512_H_ +#ifdef ENABLE_AVX512 +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulAVX512(); + +#endif +#endif // NNACL_KERNEL_MATMUL_AVX512_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.c new file mode 100644 index 00000000..35917710 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.c @@ -0,0 +1,676 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/op_base.h" + +#define kNumDeepThreshold 512 + +int MatmulFp32Run(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + MatmulStruct *matmul = (MatmulStruct *)cdata; + return matmul->parallel_run_(matmul, task_id); +} + +void MatmulBaseFreeBatchOffset(MatmulStruct *matmul) { + if (matmul->a_offset_ != NULL) { + free(matmul->a_offset_); + matmul->a_offset_ = NULL; + } + if (matmul->b_offset_ != NULL) { + free(matmul->b_offset_); + matmul->b_offset_ = NULL; + } +} + +int MatmulBaseMallocBatchOffset(MatmulStruct *matmul) { + matmul->a_offset_ = malloc(matmul->batch_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->a_offset_); + memset(matmul->a_offset_, 0, matmul->batch_ * sizeof(int)); + + matmul->b_offset_ = malloc(matmul->batch_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->b_offset_); + memset(matmul->b_offset_, 0, matmul->batch_ * sizeof(int)); + return NNACL_OK; +} + +int MatmulBasePackMatrixBParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + int start = task_id * compute->pack_b_stride_; + if (param->b_transpose_) { + int end = NNACL_MIN(matmul->compute_.col_, start + compute->pack_b_stride_); + matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->col_, compute->deep_, start, end); + } else { + int end = NNACL_MIN(matmul->compute_.deep_, start + compute->pack_b_stride_); + matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->deep_, compute->col_, start, end); + } + return NNACL_OK; +} + +int MatmulFp32PackMatrixBRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + MatmulStruct *matmul = (MatmulStruct *)cdata; + return MatmulBasePackMatrixBParallelRunByBatch(matmul, task_id); +} + +bool MatmulBaseCheckRowOptimalConditions(MatmulStruct *matmul) { + return matmul->compute_.row_ == 1 && + !(matmul->support_mul_batch_cut_by_row_ && (matmul->a_batch_ > 1 && matmul->b_batch_ == 1)); +} + +int MatmulBaseInitParameter(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + matmul->init_global_varibale_(matmul); + if (MatmulBaseCheckRowOptimalConditions(matmul)) { + compute->row_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = false; + matmul->pack_opt_ = false; + if (!matmul->b_const_ && compute->col_ <= C128NUM) { + compute->col_tile_ = 1; + matmul->out_need_aligned_ = false; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_.need_pack_ = param->b_transpose_; + } + } + if (compute->col_ == 1 && !matmul->a_const_) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1; + matmul->matrix_b_.need_pack_ = false; + matmul->pack_opt_ = false; + } + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR); + int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_; + int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_; + if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) || + (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) { + return NNACL_ERR; + } + matmul->matrix_a_.pack_size_ = a_pack_size; + matmul->matrix_b_.pack_size_ = b_pack_size; + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0)); + compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR); + compute->row_num_ = matmul->a_batch_ * compute->row_; + return NNACL_OK; +} + +int MatmulBasePackMatrixAImplOpt(MatmulStruct *matmul) { return NNACL_ERR; } + +int MatmulBasePackMatrixAImpl(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + float *src_ptr = (matmul->matrix_a_.origin_ptr_ != NULL) ? (matmul->matrix_a_.origin_ptr_) + : (float *)(matmul->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_pack_fun_ != NULL, NNACL_ERR); + for (int i = 0; i < matmul->a_batch_; i++) { + const float *src = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.row_; + float *dst = matmul->matrix_a_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.row_align_; + if (param->a_transpose_) { + matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.deep_, matmul->compute_.row_, 0, matmul->compute_.deep_); + } else { + matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.row_, matmul->compute_.deep_, 0, matmul->compute_.row_); + } + } + return NNACL_OK; +} + +int MatmulBasePackMatrixBImpl(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + + float *src_ptr = matmul->matrix_b_.origin_ptr_ != NULL ? 
matmul->matrix_b_.origin_ptr_ + : (float *)matmul->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_b_.pack_ptr_ != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_b_pack_fun_ != NULL, NNACL_ERR); + + for (int i = 0; i < matmul->b_batch_; i++) { + if (param->b_transpose_) { + matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.col_, matmul->base_.thread_nr_); + } else { + matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.deep_, matmul->base_.thread_nr_); + } + matmul->pack_b_src_ = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.col_; + matmul->pack_b_dst_ = matmul->matrix_b_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.col_align_; + int ret = matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulFp32PackMatrixBRun, matmul, + matmul->base_.thread_nr_); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + return NNACL_OK; +} + +int MatmulBasePackMatrixA(MatmulStruct *matmul) { + if (!matmul->a_const_) { + if (!matmul->matrix_a_.need_pack_) { + matmul->matrix_a_.pack_ptr_ = (float *)matmul->base_.in_[0]->data_; + return NNACL_OK; + } + if (matmul->base_.train_session_) { + matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.workspace_); + } else { + matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, + matmul->matrix_a_.pack_size_ * sizeof(float))); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + } else { + bool is_packed = false; + void *data = NULL; + size_t data_size = (size_t)(matmul->matrix_a_.pack_size_) * sizeof(float); + if (matmul->is_sharing_pack_) { + TensorC *a_matrix = matmul->base_.in_[FIRST_INPUT]; + data = matmul->get_sharing_weight_(matmul->shaing_manager_, a_matrix->data_, data_size, &is_packed); + } else { + data = malloc(data_size); + } + matmul->matrix_a_.pack_ptr_ = (float *)data; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + if (is_packed) { + return NNACL_OK; + } + } + if (matmul->pack_opt_) { + /* valid in arm64 */ + return matmul->pack_matrix_a_impl_opt_(matmul); + } + return matmul->pack_matrix_a_impl_(matmul); +} + +int MatmulBasePackMatrixB(MatmulStruct *matmul) { + if (!matmul->b_const_) { + if (!matmul->matrix_b_.need_pack_) { + matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_; + return NNACL_OK; + } + if (matmul->base_.train_session_) { + matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.workspace_) + matmul->matrix_a_.pack_size_; + } else { + matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, + matmul->matrix_b_.pack_size_ * sizeof(float))); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_); + } else { + if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) { + matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_; + return NNACL_OK; + } + bool is_packed = false; + void *data = NULL; + size_t data_size = (size_t)(matmul->matrix_b_.pack_size_) * sizeof(float); + if (matmul->is_sharing_pack_) { + TensorC *b_matrix = matmul->base_.in_[SECOND_INPUT]; + data = matmul->get_sharing_weight_(matmul->shaing_manager_, b_matrix->data_, data_size, &is_packed); + } else { + data = malloc(data_size); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(data); + matmul->matrix_b_.pack_ptr_ = (float *)data; + if (is_packed) { + return NNACL_OK; + } + } + return matmul->pack_matrix_b_impl_(matmul); +} + +int 
MatmulBaseBackupConstMatrix(MatmulStruct *matmul, MatrixInfo *matrix_info, int index) { + NNACL_CHECK_TRUE_RET(index < (int)matmul->base_.in_size_, NNACL_ERR); + size_t backup_size = (size_t)NNACLGetElementNum(matmul->base_.in_[index]) * sizeof(float); + NNACL_CHECK_TRUE_RET(backup_size > 0, NNACL_ERR); + matrix_info->origin_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, backup_size)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matrix_info->origin_ptr_); + void *src_ptr = matmul->base_.in_[index]->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_ptr); + (void)memcpy(matrix_info->origin_ptr_, src_ptr, backup_size); + matrix_info->origin_need_free_ = true; + return NNACL_OK; +} + +int MatmulBaseParallelRunByRow(MatmulStruct *matmul, int task_id) { return NNACL_ERR; } + +int MatmulBaseParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, param->act_type_, compute->deep_, compute->row_, compute->col_step_, compute->col_, + OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulBaseParallelRunIsNotPackByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + for (int index = start_batch; index < end_batch; ++index) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * matmul->compute_.row_ * matmul->compute_.deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * matmul->compute_.deep_ * matmul->compute_.col_; + float *c = matmul->output_data_ + index * matmul->compute_.row_ * matmul->compute_.col_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, matmul->compute_.row_, matmul->compute_.deep_, param->act_type_); + } + return NNACL_OK; +} + +void MatmulBaseGetThreadCuttingInfoByRow(MatmulStruct *matmul) { + int row_step = NNACL_MAX(matmul->compute_.row_num_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_num_ - row_step * matmul->base_.thread_nr_; + + int split_point = 0; + int count = 0; + while (split_point < matmul->compute_.row_num_) { + matmul->split_points_[count++] = split_point; + split_point += row_step; + if (row_remaining > 0) { + 
++split_point; + --row_remaining; + } + } + matmul->base_.thread_nr_ = count; +} + +int MatmulBaseParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul) { + if (matmul->compute_.deep_ < kNumDeepThreshold) { + if (matmul->model_thread_nr_ != -1) { + matmul->base_.thread_nr_ = matmul->model_thread_nr_; + } + } + + if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && + (matmul->b_batch_ == matmul->a_batch_ || !matmul->support_mul_batch_cut_by_row_)) || + matmul->compute_.col_ == 1) { + matmul->compute_.batch_stride_ = UP_DIV(matmul->batch_, matmul->base_.thread_nr_); + matmul->parallel_run_ = matmul->parallel_run_by_batch_; + if (matmul->compute_.col_ != 1 || matmul->a_const_) { + return; + } + + matmul->parallel_run_ = matmul->parallel_run_not_pack_by_batch_; + if (matmul->compute_.deep_ == 1) { + matmul->gemm_not_pack_fun_ = GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize; + if (matmul->check_thread_cutting_by_row_(matmul)) { + matmul->parallel_run_ = matmul->parallel_run_by_row_; + matmul->get_thread_cutting_info_by_row_(matmul); + } + } + return; + } else if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && matmul->b_batch_ == 1) || + matmul->check_thread_cutting_by_row_(matmul)) { + matmul->parallel_run_ = matmul->parallel_run_by_row_; + matmul->get_thread_cutting_info_by_row_(matmul); + } else { + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + matmul->base_.thread_nr_ = MSMIN(matmul->base_.thread_nr_, total_col_unit); + int block_col_unit = UP_DIV(total_col_unit, matmul->base_.thread_nr_); + + int count = 0; + int split_point = 0; + while (split_point < total_col_unit) { + matmul->split_points_[count++] = (split_point * matmul->compute_.col_min_unit_); + split_point += block_col_unit; + } + matmul->base_.thread_nr_ = count; + matmul->parallel_run_ = matmul->parallel_run_by_oc_; + } + return; +} + +int 
MatmulBasePackBiasMatrix(MatmulStruct *matmul) { + if (matmul->base_.in_size_ != FOURTH_INPUT) { + return NNACL_OK; + } + if (matmul->matrix_c_.has_packed_) { + NNACL_CHECK_FALSE(matmul->matrix_c_.pack_size_ < matmul->compute_.col_align_, NNACL_ERR); + return NNACL_OK; + } + TensorC *bias_tensor = matmul->base_.in_[THIRD_INPUT]; + float *bias_src = matmul->matrix_c_.origin_ptr_ != NULL ? matmul->matrix_c_.origin_ptr_ : (float *)bias_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(bias_src); + + int bias_num = NNACLGetElementNum(bias_tensor); + NNACL_CHECK_TRUE_RET(bias_num > 0 && matmul->compute_.col_align_ >= bias_num, NNACL_ERR); + + matmul->matrix_c_.pack_size_ = matmul->compute_.col_align_; + if (matmul->matrix_c_.pack_ptr_ == NULL) { + matmul->matrix_c_.pack_ptr_ = (float *)(malloc(matmul->matrix_c_.pack_size_ * sizeof(float))); + } + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_c_.pack_ptr_); + + if (bias_num == 1) { + for (int i = 0; i < matmul->matrix_c_.pack_size_; ++i) { + matmul->matrix_c_.pack_ptr_[i] = bias_src[0]; + } + } else { + (void)memcpy(matmul->matrix_c_.pack_ptr_, bias_src, bias_num * sizeof(float)); + (void)memset(matmul->matrix_c_.pack_ptr_ + bias_num, 0, (matmul->matrix_c_.pack_size_ - bias_num) * sizeof(float)); + } + if (matmul->matrix_c_.origin_need_free_) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_c_.origin_ptr_); + matmul->matrix_c_.origin_ptr_ = NULL; + matmul->matrix_c_.origin_need_free_ = false; + } + return NNACL_OK; +} + +int MatmulBaseInitTmpOutBuffer(MatmulStruct *matmul) { + if (matmul->out_need_aligned_) { + if (matmul->output_data_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_); + } + // avx need to malloc dst aligned to C8NUM + // avx512 need to malloc dst aligned to C16NUM + int out_channel = matmul->compute_.col_; + NNACL_CHECK_ZERO_RETURN_ERR(matmul->compute_.col_tile_); + int oc_block_num = UP_DIV(out_channel, matmul->compute_.col_tile_); + int ele_num = matmul->batch_ * matmul->compute_.row_ * oc_block_num * matmul->compute_.col_tile_; + int data_size = ele_num * (int)sizeof(float); + matmul->output_data_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, data_size)); + NNACL_CHECK_NULL_RETURN_ERR(matmul->output_data_); + } + return NNACL_OK; +} + +void MatmulBaseInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? 
RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; + return; +} + +bool MatmulBaseCheckThreadCuttingByRow() { return false; } + +void MatmulBaseFreePackedMatrixA(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->matrix_a_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_a_.pack_ptr_ != NULL) { + self->env_->Free(self->env_->allocator_, matmul->matrix_a_.pack_ptr_); + } + matmul->matrix_a_.pack_ptr_ = NULL; +} + +void MatmulBaseFreePackedMatrixB(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->matrix_b_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_b_.pack_ptr_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_b_.pack_ptr_); + } + matmul->matrix_b_.pack_ptr_ = NULL; +} + +int MatmulBaseResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + int ret = matmul->init_parameter_(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + if (self->train_session_) { + self->work_size_ = (matmul->matrix_a_.pack_size_ + matmul->matrix_b_.pack_size_) * (int)sizeof(float); + } + + matmul->get_thread_cutting_policy_(matmul); + if (!matmul->matrix_c_.has_packed_) { + ret = MatmulBasePackBiasMatrix(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + if (!matmul->bias_need_repack_) { + matmul->matrix_c_.has_packed_ = true; + } + } + ret = MatmulBaseInitTmpOutBuffer(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + return NNACL_OK; +} + +int MatmulBaseRelease(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + MatmulBaseFreeBatchOffset(matmul); + + if (matmul->out_need_aligned_ && matmul->output_data_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_); + matmul->output_data_ = NULL; + } + if (matmul->matrix_c_.pack_ptr_ != NULL) { + free(matmul->matrix_c_.pack_ptr_); + matmul->matrix_c_.pack_ptr_ = NULL; + } + if (matmul->a_const_) { + if (matmul->is_sharing_pack_) { + matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_a_.pack_ptr_); + } else { + free(matmul->matrix_a_.pack_ptr_); + } + } + if (matmul->b_const_) { + if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) { + return NNACL_OK; + } + if (matmul->is_sharing_pack_) { + matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_b_.pack_ptr_); + } else { + free(matmul->matrix_b_.pack_ptr_); + } + } + return NNACL_OK; +} + +int MatmulBasePrepare(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + NNACL_CHECK_FALSE(matmul->base_.in_size_ < C2NUM, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.out_size_ < 1, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR); + + if (matmul->base_.in_size_ == THREE_TENSOR) { + NNACL_CHECK_TRUE_RET(matmul->base_.in_[THIRD_INPUT]->data_type_ == kNumberTypeFloat32, NNACL_MATMUL_BIAS_INVALID); + } + + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE( + param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6, + NNACL_MATMUL_ACT_TYPE_INVALID); + + int ret = matmul->init_parameter_(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, 
ret); + + if (matmul->a_const_) { + ret = MatmulBasePackMatrixA(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + matmul->matrix_a_.has_packed_ = true; + } + if (matmul->b_const_) { + ret = MatmulBasePackMatrixB(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + matmul->matrix_b_.has_packed_ = true; + } + + if (matmul->base_.in_size_ == THREE_TENSOR) { + /* deal with const bias */ + bool bias_const = NNACLIsConst(self->in_[THIRD_INPUT]); + if (!matmul->infer_shape_ && bias_const && !matmul->base_.train_session_ && matmul->matrix_c_.origin_ptr_ == NULL) { + ret = MatmulBaseBackupConstMatrix(matmul, &matmul->matrix_c_, THIRD_INPUT); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + } + return NNACL_OK; +} + +int MatmulBaseCompute(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + float *out_data = (float *)(matmul->base_.out_[FIRST_INPUT]->data_); + NNACL_CHECK_FALSE(out_data == NULL, NNACL_ERR); + if (!matmul->out_need_aligned_) { + matmul->output_data_ = out_data; + } + + if (!matmul->a_const_) { + int ret = MatmulBasePackMatrixA(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + if (!matmul->b_const_) { + int ret = MatmulBasePackMatrixB(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_); + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MatmulFp32Run, self, self->thread_nr_); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (matmul->out_need_aligned_) { + PackNHWCXToNHWCFp32(matmul->output_data_, out_data, matmul->batch_, matmul->compute_.row_, matmul->compute_.col_, + matmul->compute_.col_tile_); + } else { + matmul->output_data_ = NULL; + } + if (!matmul->a_const_) { + MatmulBaseFreePackedMatrixA(self); + } + + if (!matmul->b_const_) { + MatmulBaseFreePackedMatrixB(self); + } + return NNACL_OK; +} + +void InitMatrixInfo(MatrixInfo *info) { + info->need_pack_ = false; + info->has_packed_ = false; + info->origin_need_free_ = false; + info->pack_size_ = -1; + info->origin_ptr_ = NULL; + info->pack_ptr_ = NULL; +} + +KernelBase *CreateMatmulBase() { + MatmulStruct *matmul = (MatmulStruct *)malloc(sizeof(MatmulStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + memset(matmul, 0, sizeof(MatmulStruct)); + matmul->base_.Prepare = MatmulBasePrepare; + matmul->base_.Resize = MatmulBaseResize; + matmul->base_.Release = MatmulBaseRelease; + matmul->base_.Compute = MatmulBaseCompute; + InitMatrixInfo(&(matmul->matrix_a_)); + InitMatrixInfo(&(matmul->matrix_b_)); + InitMatrixInfo(&(matmul->matrix_c_)); + matmul->is_sharing_pack_ = false; + matmul->pack_opt_ = false; + matmul->a_const_ = false; + matmul->b_const_ = false; + matmul->bias_need_repack_ = false; + matmul->out_need_aligned_ = false; + matmul->a_offset_ = NULL; + matmul->b_offset_ = NULL; + matmul->model_thread_nr_ = -1; + matmul->support_mul_batch_cut_by_row_ = false; + matmul->matmul_type_ = kMatmulFp32BaseCpu; + matmul->get_thread_cutting_policy_ = MatmulBaseGetThreadCuttingPolicy; + matmul->check_thread_cutting_by_row_ = MatmulBaseCheckThreadCuttingByRow; + matmul->get_thread_cutting_info_by_row_ = MatmulBaseGetThreadCuttingInfoByRow; + matmul->init_parameter_ = MatmulBaseInitParameter; + matmul->init_global_varibale_ = MatmulBaseInitGlobalVariable; + matmul->pack_matrix_a_impl_opt_ = MatmulBasePackMatrixAImplOpt; + matmul->pack_matrix_a_impl_ = MatmulBasePackMatrixAImpl; + matmul->pack_matrix_b_impl_ = MatmulBasePackMatrixBImpl; + 
matmul->parallel_run_by_batch_ = MatmulBaseParallelRunByBatch; + matmul->parallel_run_not_pack_by_batch_ = MatmulBaseParallelRunIsNotPackByBatch; + matmul->parallel_run_by_oc_ = MatmulBaseParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulBaseParallelRunByRow; + return (KernelBase *)matmul; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.h new file mode 100644 index 00000000..d840697f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_base.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_BASE_H_ +#define NNACL_KERNEL_MATMUL_BASE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/kernel/matmul_struct.h" + +void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul); +void MatmulBaseFreeBatchOffset(MatmulStruct *matmul); +int MatmulBaseMallocBatchOffset(MatmulStruct *matmul); +int MatmulBaseInitParameter(MatmulStruct *matmul); +int MatmulBasePrepare(KernelBase *self); +int MatmulBaseResize(KernelBase *self); +int MatmulBaseRelease(KernelBase *self); + +KernelBase *CreateMatmulBase(); + +#endif // NNACL_KERNEL_MATMUL_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.c new file mode 100644 index 00000000..dac46193 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.c @@ -0,0 +1,82 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
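MatmulBaseGetThreadCuttingInfoByRow above assigns each thread a start row in split_points_: every thread gets at least max(row_num_ / thread_nr_, row_min_unit_) rows, the remainder is spread one extra row at a time over the leading threads, and thread_nr_ shrinks to the number of start points actually emitted. A minimal standalone sketch of that splitting rule (SplitRows and the demo values are illustrative; only the arithmetic mirrors the patch):

#include <stdio.h>

#define MAX_THREAD_NUM 8 /* stand-in for nnacl's SPLIT_COUNT bound */

/* Mirrors MatmulBaseGetThreadCuttingInfoByRow: record a start row per
 * thread, advancing by row_step and handing the remainder out one row
 * at a time to the first threads. Returns the effective thread count. */
static int SplitRows(int row_num, int row_min_unit, int thread_nr, int *split_points) {
  int row_step = row_num / thread_nr;
  if (row_step < row_min_unit) {
    row_step = row_min_unit;
  }
  int row_remaining = row_num - row_step * thread_nr;
  int split_point = 0;
  int count = 0;
  while (split_point < row_num && count < MAX_THREAD_NUM) {
    split_points[count++] = split_point;
    split_point += row_step;
    if (row_remaining > 0) {
      ++split_point;
      --row_remaining;
    }
  }
  return count;
}

int main(void) {
  int points[MAX_THREAD_NUM];
  int used = SplitRows(10, 1, 4, points);
  for (int i = 0; i < used; ++i) {
    printf("thread %d starts at row %d\n", i, points[i]);
  }
  /* start rows 0, 3, 6, 8: two threads take 3 rows, two take 2 */
  return 0;
}

Note that when row_min_unit_ dominates, fewer start points than threads are emitted, which is how the kernel throttles thread_nr_ for small workloads.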
+ */ + +#include "nnacl_c/kernel/matmul_create.h" +#include "nnacl_c/kernel/matmul_base.h" +#if defined(ENABLE_AVX512) +#include "nnacl_c/kernel/matmul_avx512.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" +#endif + +#if defined(ENABLE_AVX) +#include "nnacl_c/kernel/matmul_avx.h" +#endif + +#if defined(ENABLE_SSE) +#include "nnacl_c/kernel/matmul_sse.h" +#endif + +#if defined(ENABLE_ARM32) +#include "nnacl_c/kernel/matmul_arm32.h" +#endif + +#if defined(ENABLE_ARM64) +#include "nnacl_c/kernel/matmul_arm64.h" +#endif + +KernelBase *CreateMatmulKernel() { + KernelBase *matmul = NULL; + +#if defined(ENABLE_AVX512) + AVX512_HARDWARE_SELF_AWARENESS_BEGIN + matmul = CreateMatmulAVX512(); + if (matmul != NULL) { + return matmul; + } + AVX512_HARDWARE_SELF_AWARENESS_END +#endif + +#if defined(ENABLE_AVX) + matmul = CreateMatmulAVX(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_SSE) + matmul = CreateMatmulSSE(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_ARM64) + matmul = CreateMatmulARM64(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_ARM32) + matmul = CreateMatmulARM32(); + if (matmul != NULL) { + return matmul; + } +#endif + + matmul = CreateMatmulBase(); + return matmul; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.h new file mode 100644 index 00000000..a5cf3b44 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_create.h @@ -0,0 +1,24 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_CREATE_H_ +#define NNACL_KERNEL_MATMUL_CREATE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulKernel(); + +#endif // NNACL_KERNEL_MATMUL_CREATE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.c new file mode 100644 index 00000000..9ee236b8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
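CreateMatmulKernel above is a priority chain: the most specialized backend that is both compiled in (the ENABLE_* guards) and, for AVX512, passes the runtime CPU check wins, and CreateMatmulBase is the unconditional fallback. The same selection pattern reduced to a table of constructors, as a self-contained sketch (the Try* functions are hypothetical stand-ins for the per-ISA creators; two of them fail here so the chain is exercised):

#include <stddef.h>
#include <stdio.h>

typedef struct { const char *name; } KernelBase;
typedef KernelBase *(*MatmulCreator)(void);

static KernelBase *TryAvx512(void) { return NULL; } /* e.g. CPU lacks AVX512 */
static KernelBase *TryAvx(void) { return NULL; }    /* e.g. not compiled in */
static KernelBase kBase = {"matmul_base"};
static KernelBase *TryBase(void) { return &kBase; } /* never fails */

static KernelBase *SelectMatmul(void) {
  /* highest-priority backend first; the base kernel closes the chain */
  MatmulCreator chain[] = {TryAvx512, TryAvx, TryBase};
  for (size_t i = 0; i < sizeof(chain) / sizeof(chain[0]); ++i) {
    KernelBase *kernel = chain[i]();
    if (kernel != NULL) {
      return kernel;
    }
  }
  return NULL; /* unreachable while TryBase succeeds */
}

int main(void) {
  printf("selected: %s\n", SelectMatmul()->name); /* prints matmul_base */
  return 0;
}

The patch keeps the chain as straight-line #if blocks rather than a table so each constructor can sit behind its own preprocessor guard.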
+ */ + +#ifdef ENABLE_SSE +#include "nnacl_c/kernel/matmul_sse.h" +#include "nnacl_c/kernel/matmul_base.h" +#include "nnacl_c/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" + +void MatmulSSEInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row4MajorParallel : RowMajor2Col4MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; + matmul->compute_.row_tile_ = C4NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; +} + +int MatmulSSEParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulSSEParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? 
NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulSSE() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulSSEInitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulSSEParallelRunByOC; + matmul->parallel_run_by_batch_ = MatmulSSEParallelRunByBatch; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.h new file mode 100644 index 00000000..78c7c0e3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_sse.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_SSE_H_ +#define NNACL_KERNEL_MATMUL_SSE_H_ +#ifdef ENABLE_SSE +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +KernelBase *CreateMatmulSSE(); + +#endif +#endif // NNACL_KERNEL_MATMUL_SSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_struct.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_struct.h new file mode 100644 index 00000000..501249cf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/matmul_struct.h @@ -0,0 +1,133 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_STRUCT_H_ +#define NNACL_KERNEL_MATMUL_STRUCT_H_ + +#include "nnacl_c/kernel.h" +#include "nnacl_c/matmul_parameter.h" + +#define SPLIT_COUNT MAX_THREAD_NUM + +typedef struct MatrixInfo { + bool need_pack_; + bool has_packed_; // only valid for constant, only do once throughout the process. + bool origin_need_free_; // true when failing to infer shape, false in conv1x1 free in convolution delegate + int pack_size_; + float *origin_ptr_; // only valid for constant, which is synchronized with the 'has_origin'. 
+ float *pack_ptr_; +} MatrixInfo; + +typedef struct MatmulSlice { + int row_s_; + int row_e_; + int col_s_; + int col_e_; +} MatmulSlice; + +typedef struct MatmulComputeParam { + int row_; + int col_; + int deep_; + int row_align_; + int col_align_; + int deep_align_; + int row_num_; + int col_tile_; + int row_tile_; + int col_step_; + int row_min_unit_; + int col_min_unit_; + int batch_stride_; + int pack_b_stride_; + int block_col_unit_; +} MatmulComputeParam; + +typedef struct MatmulStruct { + KernelBase base_; + MatmulComputeParam compute_; + MatmulType matmul_type_; + + /* model pool optimize */ + int model_thread_nr_; + + /* batch-matmul broadcast */ + int batch_; + int a_batch_; + int b_batch_; + int *a_offset_; /* batch_ size */ + int *b_offset_; /* batch_ size */ + + int split_points_[SPLIT_COUNT]; + + float *output_data_; + float *pack_b_src_; + float *pack_b_dst_; + + bool a_const_; + bool b_const_; + bool bias_need_repack_; + bool infer_shape_; + bool pack_opt_; + bool is_sharing_pack_; + bool out_need_aligned_; + bool weight_is_packed_; + bool support_mul_batch_cut_by_row_; + + MatrixInfo matrix_a_; + MatrixInfo matrix_b_; + MatrixInfo matrix_c_; + + void (*matrix_a_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row); + void (*matrix_b_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row); + + int (*pack_matrix_a_impl_opt_)(struct MatmulStruct *matmul); + int (*pack_matrix_a_impl_)(struct MatmulStruct *matmul); + int (*pack_matrix_b_impl_)(struct MatmulStruct *matmul); + + int (*init_parameter_)(struct MatmulStruct *matmul); + void (*init_global_varibale_)(struct MatmulStruct *matmul); + + bool (*check_thread_cutting_by_row_)(struct MatmulStruct *matmul); + void (*get_thread_cutting_policy_)(struct MatmulStruct *matmul); + void (*get_thread_cutting_info_by_row_)(struct MatmulStruct *matmul); + + void *shaing_manager_; + void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed); + void (*free_sharing_weight_)(void *manager, void *tensor_data); + + void (*gemm_not_pack_fun_)(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type); + + int (*parallel_run_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_row_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_oc_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_batch_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_not_pack_by_batch_)(struct MatmulStruct *matmul, int task_id); + + /* optimize for avx512 */ + int col_split_points_size_; + int row_split_points_size_; + int col_split_points_[SPLIT_COUNT]; + int row_split_points_[SPLIT_COUNT]; + int matmul_slice_count_[SPLIT_COUNT]; + MatmulSlice matmul_slice_set_[SPLIT_COUNT][SPLIT_COUNT]; + int (*parallel_run_by_gemm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_gepm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_gepdot_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_batch_col_row_gemm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_row1_deep1_gepdot_)(struct MatmulStruct *matmul, int task_id); +} MatmulStruct; + +#endif // NNACL_KERNEL_MATMUL_STRUCT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.c new file mode 100644 index 00000000..9b49f325 --- /dev/null +++ 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.c @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/nllloss.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/nllloss_fp32.h" +#include "nnacl_c/nllloss_parameter.h" + +int NlllossCompute(KernelBase *self) { + NLLLossStruct *nllloss = (NLLLossStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nllloss); + float *logits = self->in_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(logits); + int *labels = self->in_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(labels); + float *weight = self->in_[Index2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(weight); + + float *loss = self->out_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(loss); + float *total_weight = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(total_weight); + + ReductionType reduction_type = ((NLLLossParameter *)self->param_)->reduction_type_; + return NLLLoss(logits, labels, weight, loss, total_weight, nllloss, reduction_type); +} + +int NlllossPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + NLLLossStruct *nllloss = (NLLLossStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nllloss); + TensorC *logits_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(logits_tensor); + nllloss->batch_ = logits_tensor->shape_[Index0]; + nllloss->class_num_ = logits_tensor->shape_[Index1]; + return NNACL_OK; +} + +KernelBase *CreateNLLLoss(OpParameter *param, int data_type) { + NLLLossStruct *nllloss = (NLLLossStruct *)malloc(sizeof(NLLLossStruct)); + NNACL_CHECK_NULL_RETURN_NULL(nllloss); + nllloss->base_.Release = DefaultRelease; + nllloss->base_.Prepare = NlllossPrepare; + nllloss->base_.Resize = DefaultResize; + nllloss->base_.Compute = NlllossCompute; + return (KernelBase *)nllloss; +} + +REG_KERNEL_CREATOR(PrimType_NLLLoss, kNumberTypeFloat32, CreateNLLLoss) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.h new file mode 100644 index 00000000..0527b049 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/nllloss.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
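The NLLLoss routine called from NlllossCompute above is defined elsewhere in nnacl, so only its call site is visible in this diff. As a reference for the math it evaluates: logits hold log-probabilities of shape [batch, class_num], the per-sample loss is -weight[label] * logits[i][label], and mean reduction divides the summed loss by the summed selected weights, which is also what total_weight reports. A self-contained sketch under those assumptions (NllLossRef and the Reduction_* enumerator names are illustrative, not nnacl's):

#include <stdio.h>

typedef enum { Reduction_None, Reduction_Sum, Reduction_Mean } ReductionKind;

/* logits: [batch][class_num] log-probabilities; loss needs `batch`
 * slots for Reduction_None and one slot otherwise. */
static void NllLossRef(const float *logits, const int *labels, const float *weight,
                       int batch, int class_num, ReductionKind reduction,
                       float *loss, float *total_weight) {
  float sum = 0.0f;
  float weight_sum = 0.0f;
  for (int i = 0; i < batch; ++i) {
    float w = weight[labels[i]];
    float l = -w * logits[i * class_num + labels[i]];
    if (reduction == Reduction_None) {
      loss[i] = l;
    }
    sum += l;
    weight_sum += w;
  }
  *total_weight = weight_sum;
  if (reduction == Reduction_Sum) {
    loss[0] = sum;
  } else if (reduction == Reduction_Mean) {
    loss[0] = sum / weight_sum;
  }
}

int main(void) {
  /* batch = 2, class_num = 3; logits already log-softmaxed */
  const float logits[] = {-1.2f, -0.4f, -2.0f, -0.3f, -1.8f, -1.1f};
  const int labels[] = {1, 0};
  const float weight[] = {1.0f, 1.0f, 1.0f};
  float loss = 0.0f;
  float total_weight = 0.0f;
  NllLossRef(logits, labels, weight, 2, 3, Reduction_Mean, &loss, &total_weight);
  printf("mean loss = %f, total weight = %f\n", loss, total_weight); /* 0.35, 2.0 */
  return 0;
}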
+ */ + +#ifndef NNACL_KERNEL_NLLLOSS_H_ +#define NNACL_KERNEL_NLLLOSS_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int batch_; + int class_num_; +} NLLLossStruct; + +KernelBase *CreateNLLLoss(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NLLLOSS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.c new file mode 100644 index 00000000..9787dd6e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.c @@ -0,0 +1,123 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/non_max_suppression.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/non_max_suppression_parameter.h" +#include "nnacl_c/fp32/non_max_suppression_fp32.h" + +void NonMaxSuppressioExpandDims(int *dst_shape, int *origin_shape, size_t size) { + int i = 0; + for (; i < size; i++) { + dst_shape[i] = 1; + } + for (; i < Num3; i++) { + dst_shape[i] = origin_shape[i - size]; + } +} + +void NonMaxSuppressionGetParams(NonMaxSuppressionStruct *nm_suppression) { + // optional input order: max_output_per_class, iou_threshold, score_threshold + nm_suppression->max_output_per_class_ = 0; + if (nm_suppression->base_.in_size_ >= Num3) { + TensorC *max_output_tensor = nm_suppression->base_.in_[Index3]; + if (max_output_tensor != NULL && max_output_tensor->data_ != NULL) { + nm_suppression->max_output_per_class_ = *(int *)(max_output_tensor->data_); + } + } + + nm_suppression->iou_threshold_ = 0.0f; + if (nm_suppression->base_.in_size_ >= Num4) { + TensorC *iou_threshold_tensor = nm_suppression->base_.in_[Index4]; + if (iou_threshold_tensor != NULL && iou_threshold_tensor->data_ != NULL) { + nm_suppression->iou_threshold_ = *(float *)(iou_threshold_tensor->data_); + } + } + + nm_suppression->score_threshold_ = 0.0f; + if (nm_suppression->base_.in_size_ >= Num5) { + TensorC *score_threshold_tensor = nm_suppression->base_.in_[Index5]; + if (score_threshold_tensor != NULL && score_threshold_tensor->data_ != NULL) { + nm_suppression->score_threshold_ = *(float *)(score_threshold_tensor->data_); + } + } +} + +int NonMaxSuppressionCompute(KernelBase *self) { + NonMaxSuppressionStruct *nm_suppression = (NonMaxSuppressionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nm_suppression); + + NonMaxSuppressionGetParams(nm_suppression); + + TensorC *box_tensor = self->in_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(box_tensor); + int box_dims[Num3] = {0}; // batch, box_num, 4 + bool simple_out = box_tensor->shape_size_ != Num3; + /* expand size is 0 for a 3-D shape, so this call also covers the plain copy */ + NonMaxSuppressioExpandDims(box_dims, box_tensor->shape_, Num3 - box_tensor->shape_size_); + if (box_dims[Index2] != Num4) { + return NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_INVALID; + } + + TensorC *score_tensor = self->in_[Index1]; +
NNACL_CHECK_NULL_RETURN_ERR(score_tensor); + int score_dims[Num3] = {0}; // batch, class, box_num + /* expand-or-copy, mirroring the box shape handling above */ + NonMaxSuppressioExpandDims(score_dims, score_tensor->shape_, Num3 - score_tensor->shape_size_); + if (score_dims[Index0] != box_dims[Index0]) { + return NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_SCORE_UNMATCH; + } + if (score_dims[Index2] != box_dims[Index1]) { + return NNACL_NON_MAX_SUPPRESSION_DIMENSION_SPATIAL_UNMATCH; + } + if (nm_suppression->base_.out_[OUTPUT_INDEX]->data_ != NULL) { + /* output shape and data set in compute */ + return NNACL_NON_MAX_SUPPRESSION_UNSUPPORT_DEFINE_DATA; + } + return NonMaxSuppressionSelecte(nm_suppression, simple_out, score_dims); +} + +int NonMaxSuppressionPrepare(KernelBase *self) { + NonMaxSuppressionStruct *nm_suppression = (NonMaxSuppressionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nm_suppression); + + // boxes, scores, max_output_boxes, iou_threshold, score_threshold + if (self->in_size_ < Num2 || self->in_size_ > Num5 || self->out_size_ != Num1) { + return NNACL_NON_MAX_SUPPRESSION_TENSOR_SIZE_INVALID; + } + + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NMSParameter *nmparam = (NMSParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(nmparam); + if (nmparam->center_point_box_ != 0 && nmparam->center_point_box_ != 1) { + return NNACL_NON_MAX_SUPPRESSION_PARAM_INVALID; + } + nm_suppression->center_point_box_ = nmparam->center_point_box_; + return NNACL_OK; +} + +KernelBase *CreateNonMaxSuppression(OpParameter *param, int data_type) { + NonMaxSuppressionStruct *non_max_suppression = (NonMaxSuppressionStruct *)malloc(sizeof(NonMaxSuppressionStruct)); + NNACL_CHECK_NULL_RETURN_NULL(non_max_suppression); + non_max_suppression->base_.Release = DefaultRelease; + non_max_suppression->base_.Resize = DefaultResize; + non_max_suppression->base_.Prepare = NonMaxSuppressionPrepare; + non_max_suppression->base_.Compute = NonMaxSuppressionCompute; + return (KernelBase *)non_max_suppression; +} + +REG_KERNEL_CREATOR(PrimType_NonMaxSuppression, kNumberTypeFloat32, CreateNonMaxSuppression) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.h new file mode 100644 index 00000000..39d5485b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_max_suppression.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
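NonMaxSuppressionSelecte, invoked at the end of NonMaxSuppressionCompute above, is defined in non_max_suppression_fp32; the selection it parameterizes is the standard greedy rule: repeatedly keep the highest-scoring surviving box above score_threshold_ and drop every remaining box whose IoU with it exceeds iou_threshold_, per class, up to max_output_per_class_ keeps. The geometric core is the IoU. A reference for corner-format boxes {y1, x1, y2, x2} (the layout ONNX specifies when center_point_box is 0), as a sketch:

#include <stdio.h>

/* IoU of two boxes in corner format {y1, x1, y2, x2}; degenerate or
 * disjoint boxes yield 0. */
static float IoU(const float *a, const float *b) {
  float inter_y1 = a[0] > b[0] ? a[0] : b[0];
  float inter_x1 = a[1] > b[1] ? a[1] : b[1];
  float inter_y2 = a[2] < b[2] ? a[2] : b[2];
  float inter_x2 = a[3] < b[3] ? a[3] : b[3];
  float ih = inter_y2 - inter_y1;
  float iw = inter_x2 - inter_x1;
  if (ih <= 0.0f || iw <= 0.0f) {
    return 0.0f;
  }
  float inter = ih * iw;
  float area_a = (a[2] - a[0]) * (a[3] - a[1]);
  float area_b = (b[2] - b[0]) * (b[3] - b[1]);
  float uni = area_a + area_b - inter;
  return uni <= 0.0f ? 0.0f : inter / uni;
}

int main(void) {
  const float box_a[] = {0.0f, 0.0f, 2.0f, 2.0f};
  const float box_b[] = {1.0f, 1.0f, 3.0f, 3.0f};
  /* intersection 1*1 = 1, union 4 + 4 - 1 = 7 */
  printf("IoU = %f\n", IoU(box_a, box_b)); /* ~0.142857 */
  return 0;
}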
+ */ + +#ifndef NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ +#define NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int center_point_box_; + int max_output_per_class_; + float iou_threshold_; + float score_threshold_; +} NonMaxSuppressionStruct; + +KernelBase *CreateNonMaxSuppression(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.c new file mode 100644 index 00000000..6d9c46ec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.c @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/non_zero.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int NonZeroCompute(KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_2D, NNACL_NON_ZERO_SHAPE_INVALID); + + bool *input_data = (bool *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + int *output_data = (int *)output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + int non_zero_nums = output->shape_[Index1]; + int non_zero_count = 0; + + int *coordiate_values = (int *)self->env_->Alloc(self->env_->allocator_, input->shape_size_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(coordiate_values); + /* zero the coordinate odometer explicitly in case the allocator returns dirty memory */ + for (size_t j = 0; j < input->shape_size_; j++) { + coordiate_values[j] = 0; + } + + for (int i = 0; i < NNACLGetElementNum(input); i += 1) { + if (input_data[i]) { + for (size_t j = 0; j < input->shape_size_; j++) { + output_data[non_zero_count + (int)j * non_zero_nums] = coordiate_values[j]; + } + non_zero_count++; + } + for (size_t idx = input->shape_size_; idx >= 1; --idx) { + if (coordiate_values[idx - 1] != input->shape_[idx - 1] - 1) { + coordiate_values[idx - 1] = coordiate_values[idx - 1] + 1; + break; + } + coordiate_values[idx - 1] = 0; + } + } + + self->env_->Free(self->env_->allocator_, coordiate_values); + return NNACL_OK; +} + +KernelBase *CreateNonZero(OpParameter *param, int data_type) { + NonZeroStruct *non_zero = (NonZeroStruct *)malloc(sizeof(NonZeroStruct)); + NNACL_CHECK_NULL_RETURN_NULL(non_zero); + non_zero->base_.Release = DefaultRelease; + non_zero->base_.Prepare = DefaultPrepare2In1Out; + non_zero->base_.Resize = DefaultResize; + non_zero->base_.Compute = NonZeroCompute; + return (KernelBase *)non_zero; +} + +REG_KERNEL_CREATOR(PrimType_NonZero, kNumberTypeBool, CreateNonZero) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.h new file mode 100644 index 00000000..383e67e0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/non_zero.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_NON_ZERO_H_ +#define NNACL_KERNEL_NON_ZERO_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; +} NonZeroStruct; + +KernelBase *CreateNonZero(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NON_ZERO_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.c new file mode 100644 index 00000000..f303e03d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.c @@ -0,0 +1,193 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
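NonZeroCompute above walks the flattened bool input while keeping coordiate_values as an odometer: after visiting each element, the last coordinate is incremented and carries into the previous dimension whenever it wraps past its extent. The output is dimension-major: the j-th coordinate of the k-th hit lands at output_data[k + j * non_zero_nums]. A standalone sketch of the same walk for a fixed 2x3 input (the shapes and data here are illustrative):

#include <stdbool.h>
#include <stdio.h>

int main(void) {
  /* 2x3 bool input; nonzero coordinates are (0,1), (1,0), (1,2) */
  const bool input[2][3] = {{false, true, false}, {true, false, true}};
  const int shape[2] = {2, 3};
  int coord[2] = {0, 0}; /* the odometer, starting at the origin */
  int out_rows[2][3];    /* dimension-major: row j holds every j-th coordinate */
  int count = 0;

  for (int i = 0; i < 2 * 3; ++i) {
    if (input[i / 3][i % 3]) {
      for (int j = 0; j < 2; ++j) {
        out_rows[j][count] = coord[j];
      }
      ++count;
    }
    /* increment with carry, last dimension first */
    for (int d = 2; d >= 1; --d) {
      if (coord[d - 1] != shape[d - 1] - 1) {
        ++coord[d - 1];
        break;
      }
      coord[d - 1] = 0;
    }
  }

  for (int j = 0; j < 2; ++j) {
    for (int c = 0; c < count; ++c) {
      printf("%d ", out_rows[j][c]);
    }
    printf("\n"); /* prints "0 1 1" then "1 0 2" */
  }
  return 0;
}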
+ */ + +#include "nnacl_c/kernel/one_hot.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/one_hot_parameter.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/fp32/one_hot_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/one_hot_fp16.h" +#endif + +int OneHotRun(void *cdata, int task_id, float l, float r) { + OneHotStruct *one_hot = (OneHotStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + + int *indices_data = (int *)one_hot->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(indices_data); + + TensorC *output_tensor = one_hot->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + void *output_data = one_hot->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + if (output_tensor->data_type_ == kNumberTypeFloat32) { + return OneHotToFp32(indices_data, one_hot->on_value_, one_hot->off_value_, (float *)output_data, one_hot, task_id, + one_hot->base_.thread_nr_); +#ifdef ENABLE_FP16 + } else if (output_tensor->data_type_ == kNumberTypeFloat16) { + return OneHotToFp16(indices_data, (float16_t)one_hot->on_value_, (float16_t)one_hot->off_value_, + (float16_t *)output_data, one_hot, task_id, one_hot->base_.thread_nr_); +#endif + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int OneHotInitOnOffValueForFourInputs(OneHotStruct *one_hot) { + TensorC *on_value_tensor = one_hot->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(on_value_tensor); + void *on_value_data = on_value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(on_value_data); + if (on_value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->on_value_ = *((float *)on_value_data); +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (on_value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->on_value_ = *((float16_t *)on_value_data); +#endif + } else { + return NNACL_ONE_HOR_ON_VALUE_TENSOR_DATA_TYPE_INVALID; + } + + TensorC *off_value_tensor = one_hot->base_.in_[FOURTH_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(off_value_tensor); + void *off_value_data = off_value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(off_value_data); + if (on_value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->off_value_ = *((float *)off_value_data); +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (on_value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->off_value_ = *((float16_t *)off_value_data); +#endif + } else { + return NNACL_ONE_HOR_OFF_VALUE_TENSOR_DATA_TYPE_INVALID; + } + + return NNACL_OK; +} + +int OneHotInitOnOffValueForThreeInputs(OneHotStruct *one_hot) { + TensorC *value_tensor = one_hot->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(value_tensor); + void *value_data = value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(value_data); + + if (value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->off_value_ = ((float *)value_data)[Index0]; + one_hot->on_value_ = ((float *)value_data)[Index1]; +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->off_value_ = ((float16_t *)value_data)[Index0]; + one_hot->on_value_ = ((float16_t *)value_data)[Index1]; +#endif + } else { + return NNACL_ONE_HOR_ON_OFF_VALUE_TENSOR_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int OneHotInitParamsAndOnOffValue(OneHotStruct *one_hot) { + TensorC *depth_tensor = one_hot->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(depth_tensor); + + if (depth_tensor->data_type_ == kNumberTypeInt32) { + const int *depth = (int *)depth_tensor->data_; + 
NNACL_CHECK_NULL_RETURN_ERR(depth); + one_hot->depth_ = *depth; + } else { + return NNACL_ONE_HOR_DEPTH_TENSOR_DATA_TYPE_INVALID; + } + + if (one_hot->base_.in_size_ == FOUR_TENSOR) { + // 4 inputs: indices, depth, on_value, off_value + one_hot->support_neg_index_ = false; + int ret = OneHotInitOnOffValueForFourInputs(one_hot); + if (ret != NNACL_OK) { + return ret; + } + } else { + // 3 inputs: indices, depth, off_on_value + one_hot->support_neg_index_ = true; + int ret = OneHotInitOnOffValueForThreeInputs(one_hot); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int OneHotCompute(KernelBase *self) { + OneHotStruct *one_hot = (OneHotStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + int ret = OneHotInitParamsAndOnOffValue(one_hot); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, OneHotRun, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +int OneHotPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ != FOUR_TENSOR && self->in_size_ != THREE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + TypeIdC data_type = self->in_[FIRST_INPUT]->data_type_; + NNACL_CHECK_FALSE(data_type != kNumberTypeInt32 && data_type != kNumberTypeInt64, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +int OneHotResize(KernelBase *self) { + OneHotStruct *one_hot = (OneHotStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + + TensorC *indices = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices); + + int origin_axis = ((OneHotParameter *)self->param_)->axis_; + one_hot->axis_ = origin_axis < 0 ? origin_axis + (int)indices->shape_size_ + 1 : origin_axis; + NNACL_CHECK_FALSE(one_hot->axis_ < 0 || one_hot->axis_ > (int)indices->shape_size_, NNACL_ONE_HOT_AXIS_INVALID); + + one_hot->outer_size_ = 1; + for (int i = 0; i < one_hot->axis_; i++) { + one_hot->outer_size_ *= indices->shape_[i]; + } + if (one_hot->outer_size_ == 0) { + return NNACL_ONE_HOT_OUTER_SIZE_INVALID; + } + one_hot->inner_size_ = NNACLGetElementNum(indices) / one_hot->outer_size_; + NNACL_CHECK_FALSE(one_hot->inner_size_ <= 0, NNACL_ONE_HOT_INNER_SIZE_INVALID); + + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_OneHot), one_hot->inner_size_, one_hot->outer_size_, + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + return NNACL_OK; +} + +KernelBase *CreateOneHot(OpParameter *param, int data_type) { + OneHotStruct *one_hot = (OneHotStruct *)malloc(sizeof(OneHotStruct)); + NNACL_CHECK_NULL_RETURN_NULL(one_hot); + one_hot->base_.Release = DefaultRelease; + one_hot->base_.Prepare = OneHotPrepare; + one_hot->base_.Resize = OneHotResize; + one_hot->base_.Compute = OneHotCompute; + return (KernelBase *)one_hot; +} + +REG_KERNEL_CREATOR(PrimType_OneHot, kNumberTypeInt32, CreateOneHot) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.h new file mode 100644 index 00000000..d945485b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/one_hot.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
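OneHotResize above factors the indices tensor around the normalized axis into outer_size_ and inner_size_, so the output is addressed as [outer][depth][inner]: element (o, i) of the indices names the slot along the depth axis that receives on_value_, and every other slot gets off_value_. A reference expansion under those definitions (OneHotRef is illustrative; it also maps negative indices back from depth, which is what support_neg_index_ in the three-input form suggests, and is not nnacl's OneHotToFp32):

#include <stdio.h>

/* indices: outer*inner ints; output: outer*depth*inner floats laid out
 * as [outer][depth][inner]. */
static void OneHotRef(const int *indices, float on, float off, float *output,
                      int outer, int depth, int inner, int support_neg_index) {
  for (int o = 0; o < outer; ++o) {
    for (int i = 0; i < inner; ++i) {
      int idx = indices[o * inner + i];
      if (support_neg_index && idx < 0) {
        idx += depth; /* e.g. -1 selects the last slot */
      }
      for (int d = 0; d < depth; ++d) {
        output[(o * depth + d) * inner + i] = (d == idx) ? on : off;
      }
    }
  }
}

int main(void) {
  const int indices[] = {0, 2, -1, 1}; /* outer = 2, inner = 2 */
  float out[2 * 3 * 2];                /* depth = 3 */
  OneHotRef(indices, 1.0f, 0.0f, out, 2, 3, 2, 1);
  for (int k = 0; k < 12; ++k) {
    printf("%.0f%c", out[k], (k % 2 == 1) ? '\n' : ' ');
  }
  return 0;
}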
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ONE_HOT_H_ +#define NNACL_KERNEL_ONE_HOT_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int axis_; + int depth_; + int outer_size_; + int inner_size_; + bool support_neg_index_; + float on_value_; + float off_value_; +} OneHotStruct; + +KernelBase *CreateOneHot(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ONE_HOT_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.c new file mode 100644 index 00000000..fff44c15 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.c @@ -0,0 +1,67 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/ones_like.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +#define ApproximateOnesLike(output, data_size) \ + for (size_t i = 0; i < data_size; ++i) { \ + output[i] = 1; \ + } + +int OnesLikeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + void *output_ptr = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + size_t num = (size_t)NNACLGetElementNum(output_tensor); + + if (output_tensor->data_type_ == kNumberTypeFloat32) { + float *output = (float *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } +#ifdef ENABLE_FP16 + if (output_tensor->data_type_ == kNumberTypeFloat16) { + float16_t *output = (float16_t *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } +#endif + if (output_tensor->data_type_ == kNumberTypeInt32) { + int *output = (int *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +KernelBase *CreateOnesLike(OpParameter *param, int data_type) { + OnesLikeStruct *ones_like = (OnesLikeStruct *)malloc(sizeof(OnesLikeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(ones_like); + ones_like->data_type_ = data_type; + ones_like->base_.Release = DefaultRelease; + ones_like->base_.Prepare = DefaultPrepare1In1Out; + ones_like->base_.Resize = DefaultResize; + ones_like->base_.Compute = OnesLikeCompute; + return (KernelBase *)ones_like; +} + +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeInt32, CreateOnesLike) +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeFloat32, CreateOnesLike) +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeFloat16, CreateOnesLike) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.h new file mode 100644 index 00000000..1027b720 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ones_like.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ONES_LIKE_H_ +#define NNACL_KERNEL_ONES_LIKE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct OnesLikeStruct { + KernelBase base_; + int data_type_; +} OnesLikeStruct; + +KernelBase *CreateOnesLike(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ONES_LIKE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.c new file mode 100644 index 00000000..0f742ddb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.c @@ -0,0 +1,406 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
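
(Aside, not part of the patch: the ApproximateOnesLike macro above is type-generic over the output pointer; a standalone sketch of its expansion for float data, with the buffer size as an illustrative value.)

#include <stdio.h>

#define ApproximateOnesLike(output, data_size) \
  for (size_t i = 0; i < (data_size); ++i) {   \
    (output)[i] = 1;                           \
  }

int main(void) {
  float out[4] = {0};
  size_t num = sizeof(out) / sizeof(out[0]);
  ApproximateOnesLike(out, num); /* fills every element with 1 */
  printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); /* 1 1 1 1 */
  return 0;
}
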
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/pad.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/common_func.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/pad_fp16.h" +#endif +#include "nnacl_c/fp32/pad_fp32.h" + +int PadInitMirrorPadBlock(PadStruct *pad) { + int left_pads[DEFAULT_PAD_NDIMS] = {0}; + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + left_pads[i] = pad->paddings_[Num2 * i]; + } + + int input_separate_dims[DEFAULT_PAD_NDIMS] = {0}; + int output_separate_dims[DEFAULT_PAD_NDIMS] = {0}; + int separate_offset[DEFAULT_PAD_NDIMS] = {0}; + int separate_size = 0; + + /* init separate dims */ + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + input_separate_dims[separate_size] = pad->in_[i]; + output_separate_dims[separate_size] = pad->out_[i]; + separate_offset[separate_size] = left_pads[i]; + separate_size++; + } + + /* init separate stride */ + int output_separate_stride[DEFAULT_PAD_NDIMS] = {0}; + (void)GetStride(output_separate_stride, output_separate_dims, separate_size); + int remain_stride_size = 0; + int remain_size = 1; + int right_pads[DEFAULT_PAD_NDIMS] = {0}; + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; i++) { + right_pads[i] = output_separate_dims[i] - input_separate_dims[i] - separate_offset[i]; + } + + /* init pad region */ + int pad_region[DEFAULT_PAD_NDIMS] = {0}; + int pad_region_size = 0; + for (int i = remain_stride_size; i < separate_size; ++i) { + int r = 1; + r = (separate_offset[i] > 0) ? (r + 1) : r; + r = (right_pads[i] > 0) ? 
(r + 1) : r; + pad_region[pad_region_size++] = r; + } + int pad_region_stride[DEFAULT_PAD_NDIMS] = {0}; + int region_size = GetStride(pad_region_stride, pad_region, pad_region_size); + + /* init mirror block info */ + int max_block_size = remain_size * region_size * sizeof(MirrorPadBlock); + pad->mirror_pad_block_ = (MirrorPadBlock *)pad->base_.env_->Alloc(pad->base_.env_->allocator_, max_block_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pad->mirror_pad_block_); + + // 0: center, 1: left, 2: right + int pad_cord[DEFAULT_PAD_NDIMS] = {0}; + + for (int pos = 0; pos < remain_size; ++pos) { + const int dst_basic_offset = 0; + for (int index = 1; index < region_size; ++index) { + int dst_offset = dst_basic_offset; + int value = index; + for (size_t i = 0; i < pad_region_size && pad_region_stride[i] != 0; ++i) { + NNACL_CHECK_ZERO_RETURN_ERR(pad_region_stride[i]); + pad_cord[i] = value / pad_region_stride[i]; + value = value % pad_region_stride[i]; + } + MirrorPadBlock block; + const int size_offset = DEFAULT_PAD_NDIMS - pad_region_size; + for (size_t i = 0; i < pad_region_size; ++i) { + int di = size_offset + i; + int si = remain_stride_size + i; + if (di >= DEFAULT_PAD_NDIMS) { + continue; + } + switch (pad_cord[i]) { + case Num0: + dst_offset += separate_offset[si] * output_separate_stride[si]; + block.size_[di] = input_separate_dims[si]; + block.out_stride_[di] = output_separate_stride[si]; + break; + case Num2: + dst_offset += (separate_offset[si] + input_separate_dims[si]) * output_separate_stride[si]; + block.size_[di] = right_pads[si]; + block.out_stride_[di] = output_separate_stride[si]; + break; + case Num1: + if (separate_offset[si] > 0) { + block.size_[di] = separate_offset[si]; + block.out_stride_[di] = output_separate_stride[si]; + } else { + dst_offset += (separate_offset[si] + input_separate_dims[si]) * output_separate_stride[si]; + block.size_[di] = right_pads[si]; + block.out_stride_[di] = output_separate_stride[si]; + } + break; + default: + break; + } + } + block.out_offset_ = dst_offset; + pad->mirror_pad_block_[pad->mirror_pad_block_size_++] = block; + } + } + return NNACL_OK; +} + +int PadExtendDims(int *dims, const int *origin_dims, int max_dim, int origin_dim, int init_value) { + NNACL_CHECK_NULL_RETURN_ERR(dims); + NNACL_CHECK_NULL_RETURN_ERR(origin_dims); + for (int i = 0; i < max_dim - origin_dim; ++i) { + dims[i] = init_value; + } + for (int i = max_dim - origin_dim; i < max_dim; ++i) { + dims[i] = origin_dims[i - (max_dim - origin_dim)]; + } + return NNACL_OK; +} + +int PadImpl(void *cdata, int task_id, float l, float r) { + PadStruct *pad = (PadStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(pad); + void *input = pad->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + void *output = pad->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output); + + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + PadFp16(input, output, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); +#endif + } else { + Pad((float *)input, (float *)output, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); + } + return NNACL_OK; +} + +int PadFastMirrorRunImpl(PadStruct *pad, int task_id) { + void *in = pad->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *out = pad->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out); + + /* copy center part */ + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + PadFp16((float16_t *)in, (float16_t *)out, pad->in_, 
pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); +#endif + } else { + Pad((float *)in, (float *)out, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); + } + + /* calculate region part */ + for (int i = task_id; i < pad->mirror_pad_block_size_; i += pad->base_.thread_nr_) { + MirrorPadBlock *block = &pad->mirror_pad_block_[i]; + for (int a = 0; a < block->size_[FIRST_INPUT]; a++) { + int out_a_index = block->out_offset_ + a * block->out_stride_[FIRST_INPUT]; + for (int b = 0; b < block->size_[SECOND_INPUT]; b++) { + int out_b_index = out_a_index + b * block->out_stride_[SECOND_INPUT]; + for (int c = 0; c < block->size_[THIRD_INPUT]; ++c) { + int out_c_index = out_b_index + c * block->out_stride_[THIRD_INPUT]; + for (int d = 0; d < block->size_[FOURTH_INPUT]; ++d) { + int out_d_index = out_c_index + d * block->out_stride_[FOURTH_INPUT]; + for (int e = 0; e < block->size_[FIFTH_INPUT]; ++e) { + int start_index = out_d_index + e * block->out_stride_[FIFTH_INPUT]; + int end_index = start_index + block->size_[SIXTH_INPUT]; + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + MirrorPadFp16(in, out, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, + pad->mirror_offset_, start_index, end_index); +#endif + } else { + MirrorPad(in, out, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, pad->mirror_offset_, + start_index, end_index); + } + } + } + } + } + } + } + return NNACL_OK; +} + +int MirrorPadImpl(void *cdata, int task_id, float l, float r) { + PadStruct *pad = (PadStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(pad); + + /* Fast Mirror pad */ + if (pad->mirror_pad_block_size_ != 0) { + return PadFastMirrorRunImpl(pad, task_id); + } + + TensorC *input = pad->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + TensorC *output = pad->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + /* Common Mirror pad */ + int unit = UP_DIV(NNACLGetElementNum(output), pad->base_.thread_nr_); + int begin = unit * task_id; + int end = NNACL_MIN(begin + unit, NNACLGetElementNum(output)); + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + MirrorPadFp16((float16_t *)input_data, (float16_t *)output_data, pad->in_, pad->in_strides_, pad->out_strides_, + pad->paddings_, pad->mirror_offset_, begin, end); +#endif + } else { + MirrorPad((float *)input_data, (float *)output_data, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, + pad->mirror_offset_, begin, end); + } + return NNACL_OK; +} + +int PadCheckPaddings(const int *paddings, int length, const int *input_shape, int mode) { + NNACL_CHECK_NULL_RETURN_ERR(paddings); + NNACL_CHECK_NULL_RETURN_ERR(input_shape); + int offset = mode == PaddingMode_Symmetric ? 
0 : 1; + for (int i = 0; i < length; ++i) { + int max_valid = input_shape[i] - offset; + if (paddings[i * Num2] > max_valid) { + return NNACL_PAD_MIRROR_PAD_SIZE_INVALID; + } + if (paddings[i * Num2 + 1] > max_valid) { + return NNACL_PAD_MIRROR_PAD_SIZE_INVALID; + } + } + return NNACL_OK; +} + +int PadCopyPaddingFromInput(PadStruct *pad) { + TensorC *input_tensor = pad->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *padding_tensor = pad->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(padding_tensor); + int *padding_data = padding_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(padding_data); + + (void)PadExtendDims(pad->in_, input_tensor->shape_, DEFAULT_PAD_NDIMS, input_tensor->shape_size_, 1); + (void)PadExtendDims(pad->paddings_, padding_data, MAX_PAD_SIZE, NNACLGetElementNum(padding_tensor), 0); + pad->paddings_size_ = MAX_PAD_SIZE; + + return NNACL_OK; +} + +void PadCalculateStrides(PadStruct *pad) { + pad->in_strides_[DEFAULT_PAD_NDIMS - 1] = 1; + for (int i = DEFAULT_PAD_NDIMS - Num2; i >= 0; --i) { + pad->in_strides_[i] = pad->in_[i + 1] * pad->in_strides_[i + 1]; + } + for (int i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + pad->out_[i] = pad->in_[i] + pad->paddings_[i * Num2] + pad->paddings_[i * Num2 + 1]; + } + pad->out_strides_[DEFAULT_PAD_NDIMS - 1] = 1; + for (int i = DEFAULT_PAD_NDIMS - Num2; i >= 0; --i) { + pad->out_strides_[i] = pad->out_[i + 1] * pad->out_strides_[i + 1]; + } +} + +int PadHandleMirrorPad(PadStruct *pad) { + pad->mirror_offset_ = pad->pad_mode_ == PaddingMode_Reflect ? 1 : 0; + (void)PadCheckPaddings(pad->paddings_, DEFAULT_PAD_NDIMS, pad->in_, pad->pad_mode_); + PadCalculateStrides(pad); + return PadInitMirrorPadBlock(pad); +} + +int PadCompute(KernelBase *self) { + PadStruct *pad = (PadStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pad); + + if (self->in_size_ == THREE_TENSOR) { + TensorC *pad_value_tensor = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(pad_value_tensor); + NNACL_CHECK_FALSE(NNACLGetElementNum(pad_value_tensor) != 1, NNACL_PAD_PADDING_VALID_INVALID); + void *pad_value = pad_value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(pad_value); + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + pad->constant_value_ = ((float16_t *)pad_value)[Index0]; +#endif + } else { + pad->constant_value_ = ((float *)pad_value)[Index0]; + } + } + + int ret = PadCopyPaddingFromInput(pad); + if (ret != NNACL_OK) { + return ret; + } + + if (pad->pad_mode_ == PaddingMode_Constant) { + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + size_t output_size = NNACLGetElementNum(output); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + if (fabsf(pad->constant_value_ - 0.0f) < 1e-5) { + memset(output_data, 0, output_size * (int)DataTypeCSize(pad->data_type_)); + } else { + for (size_t i = 0; i < output_size; ++i) { + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + ((float16_t *)output_data)[i] = pad->constant_value_; +#endif + } else { + ((float *)output_data)[i] = pad->constant_value_; + } + } + } + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, PadImpl, self, self->thread_nr_); + return ret; + } + + /* non-constant pad modes use the mirror pad algorithm */ + ret = PadHandleMirrorPad(pad); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MirrorPadImpl, self, self->thread_nr_); + + self->env_->Free(self->env_->allocator_, pad->mirror_pad_block_); + pad->mirror_pad_block_ = NULL; + pad->mirror_pad_block_size_ = 0; + return ret; +} + +int
PadResize(KernelBase *self) { + PadStruct *pad = (PadStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pad); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *padding = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(padding); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int rank = input->shape_size_; + NNACL_CHECK_FALSE(input->shape_size_ > DEFAULT_PAD_NDIMS, NNACL_PAD_SHAPE_INVALID); + NNACL_CHECK_FALSE(NNACLGetElementNum(padding) != rank + rank, NNACL_PAD_SHAPE_INVALID); + + if (pad->pad_mode_ == PaddingMode_Constant) { + (void)PadExtendDims(pad->in_, input->shape_, DEFAULT_PAD_NDIMS, rank, 1); + (void)PadExtendDims(pad->out_, output->shape_, DEFAULT_PAD_NDIMS, rank, 1); + + if (pad->paddings_size_ < MAX_PAD_SIZE) { + int ori_paddings[MAX_PAD_SIZE]; + memcpy(ori_paddings, pad->paddings_, MAX_PAD_SIZE * sizeof(int)); + (void)PadExtendDims(pad->paddings_, ori_paddings, MAX_PAD_SIZE, pad->paddings_size_, 0); + pad->paddings_size_ = MAX_PAD_SIZE; + } + } + return NNACL_OK; +} + +int PadPrepare(KernelBase *self) { + NNACL_CHECK_TRUE_RET(self->in_size_ == TWO_TENSOR || self->in_size_ == THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_TRUE_RET(self->out_size_ == ONE_TENSOR, NNACL_ERR); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_FALSE(input->data_type_ != kNumberTypeFloat32 && input->data_type_ != kNumberTypeFloat16, NNACL_ERR); + return NNACL_OK; +} + +KernelBase *CreatePad(OpParameter *param, int data_type) { + PadStruct *pad = (PadStruct *)malloc(sizeof(PadStruct)); + NNACL_CHECK_NULL_RETURN_NULL(pad); + memset(pad, 0, sizeof(PadStruct)); + + pad->data_type_ = data_type; + + PadParameter *pad_param = (PadParameter *)param; + pad->pad_mode_ = pad_param->pad_mode_; + pad->constant_value_ = pad_param->constant_value_; + pad->paddings_size_ = pad_param->padding_length; + memcpy(pad->paddings_, pad_param->paddings_, MAX_PAD_SIZE * sizeof(int)); + + pad->base_.Release = DefaultRelease; + pad->base_.Prepare = PadPrepare; + pad->base_.Resize = PadResize; + pad->base_.Compute = PadCompute; + return (KernelBase *)pad; +} + +REG_KERNEL_CREATOR(PrimType_PadFusion, kNumberTypeFloat32, CreatePad) +REG_KERNEL_CREATOR(PrimType_PadFusion, kNumberTypeFloat16, CreatePad) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.h new file mode 100644 index 00000000..157e3c87 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pad.h @@ -0,0 +1,51 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
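
(Aside, not part of the patch: a standalone sketch of the left-fill dimension extension used by PadResize and PadCopyPaddingFromInput above; shapes are extended to DEFAULT_PAD_NDIMS = 6 with 1 and paddings to MAX_PAD_SIZE with 0, and the 3D shape here is an illustrative value.)

#include <stdio.h>

static void extend_dims(int *dims, const int *origin, int max_dim, int origin_dim, int init_value) {
  /* mirrors PadExtendDims: prefix with init_value, then copy the original dims */
  for (int i = 0; i < max_dim - origin_dim; ++i) dims[i] = init_value;
  for (int i = max_dim - origin_dim; i < max_dim; ++i) dims[i] = origin[i - (max_dim - origin_dim)];
}

int main(void) {
  int shape3[] = {4, 5, 3};
  int shape6[6];
  extend_dims(shape6, shape3, 6, 3, 1); /* a 3D shape viewed as 6D */
  for (int i = 0; i < 6; ++i) printf("%d ", shape6[i]); /* 1 1 1 4 5 3 */
  printf("\n");
  return 0;
}
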
+ */ + +#ifndef NNACL_KERNEL_PAD_H_ +#define NNACL_KERNEL_PAD_H_ + +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/pad_parameter.h" + +typedef struct MirrorPadBlock { + int out_offset_; + int out_stride_[DEFAULT_PAD_NDIMS]; + int size_[DEFAULT_PAD_NDIMS]; +} MirrorPadBlock; + +typedef struct PadStruct { + KernelBase base_; + int data_type_; + int mirror_offset_; + float constant_value_; + int pad_mode_; + int paddings_[MAX_PAD_SIZE]; + int paddings_size_; + int in_[DEFAULT_PAD_NDIMS]; + int out_[DEFAULT_PAD_NDIMS]; + int in_strides_[DEFAULT_PAD_NDIMS]; + int out_strides_[DEFAULT_PAD_NDIMS]; + MirrorPadBlock *mirror_pad_block_; + int mirror_pad_block_size_; +} PadStruct; + +KernelBase *CreatePad(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_PAD_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.c new file mode 100644 index 00000000..752021ce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.c @@ -0,0 +1,159 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/pooling.h" +#include <float.h> +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/fp32/pooling_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/pooling_fp16.h" +#endif + +int PoolingF16RunImpl(PoolingStruct *pooling, int task_id) { +#ifdef ENABLE_FP16 + PoolingParameter *param = (PoolingParameter *)pooling->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float16_t *input_ptr = (float16_t *)pooling->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + float16_t *output_ptr = (float16_t *)pooling->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + + if (param->pool_mode_ == PoolMode_MaxPool) { + MaxPoolingFp16(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + return NNACL_OK; + } else { + return AvgPoolingFp16(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + } +#endif + return NNACL_DISABLE_FP16; +} + +int PoolingRunImpl(PoolingStruct *pooling, int task_id) { + PoolingParameter *param = (PoolingParameter *)pooling->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + TensorC *input_tensor = pooling->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_ptr = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + float *output_ptr = (float *)pooling->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + + if (input_tensor->format_ == Format_NC4HW4) { + if (param->pool_mode_ == PoolMode_MaxPool) { + return MaxPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, param, &pooling->compute_, task_id, + pooling->base_.thread_nr_); + } else { + return
AvgPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, param, &pooling->compute_, task_id, + pooling->base_.thread_nr_); + } + } else if (input_tensor->format_ == Format_NHWC) { + if (param->pool_mode_ == PoolMode_MaxPool) { + return MaxPooling(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + } else { + return AvgPooling(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + } + } + + return NNACL_UNSUPPORTED_FORMAT; +} + +int PoolingImpl(void *cdata, int task_id, float l, float r) { + PoolingStruct *pooling = (PoolingStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cdata); + if (pooling->data_type_ == kNumberTypeFloat16) { + return PoolingF16RunImpl(pooling, task_id); + } else if (pooling->data_type_ == kNumberTypeFloat32) { + return PoolingRunImpl(pooling, task_id); + } + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int PoolingCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, PoolingImpl, self, self->thread_nr_); +} + +int PoolingResize(KernelBase *self) { + PoolingStruct *pooling = (PoolingStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pooling); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + + PoolingComputeParam *compute = &pooling->compute_; + PoolingParameter *param = (PoolingParameter *)self->param_; + + compute->input_batch_ = NNACLGetBatch(in_tensor); + compute->input_channel_ = NNACLGetChannel(in_tensor); + compute->input_h_ = NNACLGetHeight(in_tensor); + compute->input_w_ = NNACLGetWidth(in_tensor); + compute->output_batch_ = NNACLGetBatch(out_tensor); + compute->output_channel_ = NNACLGetChannel(out_tensor); + compute->output_h_ = NNACLGetHeight(out_tensor); + compute->output_w_ = NNACLGetWidth(out_tensor); + compute->window_h_ = param->window_h_; + compute->window_w_ = param->window_w_; + if (param->global_) { + compute->window_h_ = compute->input_h_; + compute->window_w_ = compute->input_w_; + } + return NNACL_OK; +} + +int PoolingPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < 1, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < 1, NNACL_ERR); + + PoolingStruct *pooling = (PoolingStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pooling); + PoolingParameter *param = (PoolingParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float minf = pooling->data_type_ == kNumberTypeFloat32 ? -FLT_MAX : -FLT16_MAX; + float maxf = pooling->data_type_ == kNumberTypeFloat32 ? 
FLT_MAX : FLT16_MAX; + + if (param->act_type_ == ActType_Relu) { + minf = 0.f; + } else if (param->act_type_ == ActType_Relu6) { + minf = 0.f; + maxf = 6.f; + } + pooling->compute_.minf = minf; + pooling->compute_.maxf = maxf; + + return NNACL_OK; +} + +KernelBase *CreatePooling(OpParameter *param, int data_type) { + PoolingStruct *pooling = (PoolingStruct *)malloc(sizeof(PoolingStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(pooling); + memset(pooling, 0, sizeof(PoolingStruct)); + pooling->data_type_ = data_type; + pooling->base_.Release = DefaultRelease; + pooling->base_.Prepare = PoolingPrepare; + pooling->base_.Resize = PoolingResize; + pooling->base_.Compute = PoolingCompute; + return (KernelBase *)pooling; +} + +REG_KERNEL_CREATOR(PrimType_AvgPoolFusion, kNumberTypeFloat16, CreatePooling) +REG_KERNEL_CREATOR(PrimType_MaxPoolFusion, kNumberTypeFloat16, CreatePooling) +REG_KERNEL_CREATOR(PrimType_AvgPoolFusion, kNumberTypeFloat32, CreatePooling) +REG_KERNEL_CREATOR(PrimType_MaxPoolFusion, kNumberTypeFloat32, CreatePooling) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.h new file mode 100644 index 00000000..7a95f0fc --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pooling.h @@ -0,0 +1,54 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_POOLING_H_ +#define NNACL_KERNEL_POOLING_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PoolingComputeParam { + int input_w_; + int input_h_; + int input_batch_; + int input_channel_; + int output_w_; + int output_h_; + int output_batch_; + int output_channel_; + int window_w_; + int window_h_; + float minf; + float maxf; +} PoolingComputeParam; + +typedef struct Pooling3DComputeParam { + PoolingComputeParam pooling_compute_param_; + int input_d_; + int output_d_; + int window_d_; +} Pooling3DComputeParam; + +typedef struct PoolingStruct { + KernelBase base_; + PoolingComputeParam compute_; + int data_type_; +} PoolingStruct; + +KernelBase *CreatePooling(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_POOLING_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.c new file mode 100644 index 00000000..39198cd3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.c @@ -0,0 +1,79 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
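
(Aside, not part of the patch: a standalone sketch of two details of PoolingResize/PoolingPrepare above; global pooling widens the window to the full input plane, and the activation type selects the clamp range applied to pooled values. All concrete numbers are illustrative.)

#include <float.h>
#include <stdio.h>

int main(void) {
  /* global pooling: the window becomes the whole spatial extent */
  int input_h = 14, input_w = 14, window_h = 3, window_w = 3, global = 1;
  if (global) { window_h = input_h; window_w = input_w; }

  /* the activation selects the [minf, maxf] clamp for the outputs */
  float minf = -FLT_MAX, maxf = FLT_MAX; /* no activation */
  int act_relu6 = 1;
  if (act_relu6) { minf = 0.f; maxf = 6.f; }

  printf("window %dx%d, clamp [%g, %g]\n", window_h, window_w, minf, maxf);
  return 0;
}
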
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/pow.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/power_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/power_fp16.h" +#endif + +int PowImpl(void *cdata, int task_id, float l, float r) { + PowStruct *pow = (PowStruct *)cdata; + TensorC *input0 = pow->base_.in_[FIRST_INPUT]; + TensorC *input1 = pow->base_.in_[SECOND_INPUT]; + TensorC *output = pow->base_.out_[OUTPUT_INDEX]; + + int size = NNACLGetElementNum(input0); + int stride = UP_DIV(size, pow->base_.thread_nr_); + int len = MSMIN(stride, size - stride * task_id); + if (len <= 0) { + return NNACL_OK; + } + bool broadcast = !ShapeEqual(input0->shape_, input0->shape_size_, input1->shape_, input1->shape_size_); + float scale = ((PowParameter *)pow->base_.param_)->scale_; + float shift = ((PowParameter *)pow->base_.param_)->shift_; + int task_stride = stride * task_id; + + uint8_t *exp_addr = (uint8_t *)input1->data_; + void *cur_exp = NULL; + if (broadcast) { + cur_exp = exp_addr; + } else { + cur_exp = exp_addr + task_stride * DataTypeCSize(pow->data_type_); + } + + if (pow->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + return PowerFp16((float16_t *)input0->data_ + task_stride, (float16_t *)cur_exp, + (float16_t *)output->data_ + task_stride, len, scale, shift, broadcast); +#endif + } else if (pow->data_type_ == kNumberTypeFloat32) { + return Power((float *)input0->data_ + task_stride, (float *)cur_exp, (float *)output->data_ + task_stride, len, + scale, shift, broadcast); + } + return NNACL_POW_INVALID_DATA_TYPE; +} + +int PowCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, PowImpl, self, self->thread_nr_); +} + +KernelBase *CreatePow(OpParameter *param, int data_type) { + PowStruct *pow = (PowStruct *)malloc(sizeof(PowStruct)); + NNACL_CHECK_NULL_RETURN_NULL(pow); + pow->data_type_ = data_type; + pow->base_.Release = DefaultRelease; + pow->base_.Prepare = DefaultPrepare2In1Out; + pow->base_.Resize = DefaultResize; + pow->base_.Compute = PowCompute; + return (KernelBase *)pow; +} + +REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat32, CreatePow) +REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat16, CreatePow) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.h new file mode 100644 index 00000000..b7788722 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/pow.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
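
(Aside, not part of the patch: the slice arithmetic PowImpl above uses to split the element range across threads; a standalone sketch with an illustrative size of 10 elements over 4 threads.)

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  int size = 10, thread_nr = 4;
  int stride = UP_DIV(size, thread_nr); /* 3 elements per task */
  for (int task_id = 0; task_id < thread_nr; ++task_id) {
    int len = MSMIN(stride, size - stride * task_id); /* the last task gets the remainder */
    if (len <= 0) continue;
    printf("task %d: offset %d, len %d\n", task_id, stride * task_id, len);
  }
  return 0;
}
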
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_POW_H_ +#define NNACL_KERNEL_POW_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PowStruct { + KernelBase base_; + int data_type_; +} PowStruct; + +KernelBase *CreatePow(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_POW_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.c new file mode 100644 index 00000000..72499684 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.c @@ -0,0 +1,111 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/prelu.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/fp32/prelu_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/prelu_fp16.h" +#endif + +int PReluRun(void *cdata, int task_id, float l, float r) { + PReluStruct *prelu = (PReluStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + + int thread_num = prelu->base_.thread_nr_; + int num = prelu->channel_shared_ ? 
prelu->input_num_ : prelu->input_num_ / prelu->channel_num_; + int step = UP_DIV(num, thread_num); + int start = task_id * step; + int end = MSMIN(start + step, num); + + void *in_data = prelu->base_.in_[FIRST_INPUT]->data_; + void *out_data = prelu->base_.out_[OUTPUT_INDEX]->data_; + void *slope_data = prelu->base_.in_[SECOND_INPUT]->data_; + + if (prelu->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + if (prelu->channel_shared_) { + PReluShareChannelFp16((float16_t *)in_data, (float16_t *)out_data, ((float16_t *)slope_data)[0], start, end); + } else { + PReluFp16((float16_t *)in_data, (float16_t *)out_data, (float16_t *)slope_data, start, end, prelu->channel_num_); + } +#endif + } else { + if (prelu->channel_shared_) { + PReluShareChannel((float *)in_data, (float *)out_data, ((float *)slope_data)[0], start, end); + } else { + PRelu((float *)in_data, (float *)out_data, (float *)slope_data, start, end, prelu->channel_num_); + } + } + return NNACL_OK; +} + +int PReluPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int PReluResize(KernelBase *self) { + PReluStruct *prelu = (PReluStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + prelu->input_num_ = NNACLGetElementNum(input); + prelu->channel_num_ = NNACLGetChannel(input); + return NNACL_OK; +} + +int PReluCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + PReluStruct *prelu = (PReluStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *slope = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(slope); + + int slope_num = NNACLGetElementNum(slope); + if (slope_num == Num1) { + prelu->channel_shared_ = true; + } else if (slope_num == NNACLGetChannel(input)) { + prelu->channel_shared_ = false; + } else { + return NNACL_PRELU_SLOPE_NUM_INVALID; + } + return self->env_->ParallelLaunch(self->env_->thread_pool_, PReluRun, self, self->thread_nr_); +} + +KernelBase *CreatePRelu(OpParameter *param, int data_type) { + PReluStruct *prelu = (PReluStruct *)malloc(sizeof(PReluStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(prelu); + memset(prelu, 0, sizeof(PReluStruct)); + prelu->data_type_ = data_type; + prelu->base_.Prepare = PReluPrepare; + prelu->base_.Resize = PReluResize; + prelu->base_.Compute = PReluCompute; + prelu->base_.Release = DefaultRelease; + return (KernelBase *)prelu; +} + +REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat16, CreatePRelu) +REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat32, CreatePRelu) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.h new file mode 100644 index 00000000..e38d1d93 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prelu.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
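
(Aside, not part of the patch: the slope-mode decision PReluCompute above derives from the slope tensor's element count; a standalone sketch with an illustrative channel count of 16.)

#include <stdio.h>

static const char *slope_mode(int slope_num, int channel_num) {
  if (slope_num == 1) return "channel_shared";        /* one slope broadcast to all channels */
  if (slope_num == channel_num) return "per_channel"; /* one slope per channel */
  return "invalid";                                   /* rejected as NNACL_PRELU_SLOPE_NUM_INVALID */
}

int main(void) {
  printf("%s %s %s\n", slope_mode(1, 16), slope_mode(16, 16), slope_mode(4, 16));
  return 0;
}
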
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_PRELU_H_ +#define NNACL_KERNEL_PRELU_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PReluStruct { + KernelBase base_; + int data_type_; + int input_num_; + int channel_num_; + bool channel_shared_; +} PReluStruct; + +KernelBase *CreatePRelu(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_PRELU_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.c new file mode 100644 index 00000000..05438d90 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.c @@ -0,0 +1,190 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/prior_box.h" +#include <math.h> +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/prior_box_fp32.h" +#include "nnacl_c/tensor_c_utils.h" + +int PriorBoxInitOutput(PriorBoxStruct *prior_box, const PriorBoxParameter *param, const float *different_aspect_ratios, + int different_aspect_ratios_size) { + for (int i = 0; i < prior_box->fmap_h_; i++) { + float cy = i + param->offset; + for (int j = 0; j < prior_box->fmap_w_; j++) { + float cx = j + param->offset; + for (int32_t k = 0; k < param->min_sizes_size; k++) { + float min = param->min_sizes[k]; + prior_box->output_[prior_box->output_size_++] = (cx - min / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy - min / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + prior_box->output_[prior_box->output_size_++] = (cx + min / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy + min / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + + if (param->max_sizes_size > 0) { + float max = param->max_sizes[k]; + NNACL_CHECK_FALSE(min * max <= 0, NNACL_PRIOR_BOX_VALUE_INVALID); + float prime = sqrt(min * max); + prior_box->output_[prior_box->output_size_++] = (cx - prime / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy - prime / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + prior_box->output_[prior_box->output_size_++] = (cx + prime / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy + prime / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + } + + for (int m = 0; m < different_aspect_ratios_size; m++) { + float v = different_aspect_ratios[m]; + if (fabs(v - 1.0f) < 1e-6) {
continue; + } + NNACL_CHECK_FALSE(v <= 0, NNACL_PRIOR_BOX_VALUE_INVALID); + float as_square_root = sqrt(v); + NNACL_CHECK_FALSE(as_square_root <= 0, NNACL_PRIOR_BOX_VALUE_INVALID); + float box_w = min * as_square_root; + float box_h = min / as_square_root; + prior_box->output_[prior_box->output_size_++] = (cx - box_w / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy - box_h / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + prior_box->output_[prior_box->output_size_++] = (cx + box_w / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy + box_h / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + } + } + } + } + return NNACL_OK; +} + +int RunPriorBox(void *cdata, int task_id, float l, float r) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(prior_box); + TensorC *output_tensor = prior_box->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *output_data = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + return PriorBox(prior_box->output_, output_data, NNACLGetSize(output_tensor), task_id, prior_box->base_.thread_nr_); +} + +int PriorBoxRelease(KernelBase *self) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prior_box); + if (prior_box->output_ != NULL) { + self->env_->Free(self->env_->allocator_, prior_box->output_); + prior_box->output_ = NULL; + prior_box->output_size_ = 0; + } + return NNACL_OK; +} + +int PriorBoxResize(KernelBase *self) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prior_box); + PriorBoxParameter *param = (PriorBoxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + TensorC *input0_tensor = prior_box->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input0_tensor); + TensorC *input1_tensor = prior_box->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input1_tensor); + TensorC *output_tensor = prior_box->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + prior_box->fmap_w_ = NNACLGetWidth(input0_tensor); + NNACL_CHECK_ZERO_RETURN_ERR(prior_box->fmap_w_); + prior_box->fmap_h_ = NNACLGetHeight(input0_tensor); + NNACL_CHECK_ZERO_RETURN_ERR(prior_box->fmap_h_); + const int image_w = param->image_size_w > 0 ? param->image_size_w : NNACLGetWidth(input1_tensor); + const int image_h = param->image_size_h > 0 ? param->image_size_h : NNACLGetHeight(input1_tensor); + + prior_box->step_w_ = param->step_w > 0.0f ? param->step_w : (float)(image_w) / prior_box->fmap_w_; + prior_box->step_h_ = param->step_h > 0.0f ?
param->step_h : (float)(image_h) / prior_box->fmap_h_; + + float *different_aspect_ratios = + (float *)self->env_->Alloc(self->env_->allocator_, (param->aspect_ratios_size * Num2 + 1) * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(different_aspect_ratios); + different_aspect_ratios[Index0] = 1.0f; + int different_aspect_ratios_size = 1; + + float *aspect_ratios = param->aspect_ratios; + for (int32_t i = 0; i < param->aspect_ratios_size; i++) { + float ratio = aspect_ratios[i]; + + bool exist = false; + for (int k = 0; k < different_aspect_ratios_size; k++) { + if (fabs(ratio - different_aspect_ratios[k]) < 1e-6) { + exist = true; + } + } + + if (!exist) { + different_aspect_ratios[different_aspect_ratios_size++] = ratio; + if (param->flip) { + NNACL_CHECK_FALSE(fabs(ratio) <= 1e-5, NNACL_PRIOR_BOX_RATIO_INVALID); + different_aspect_ratios[different_aspect_ratios_size++] = 1.0f / ratio; + } + } + } + + PriorBoxRelease(self); + int size = Num4 + Num4 + Num4 * different_aspect_ratios_size; + size = size * prior_box->fmap_h_ * prior_box->fmap_w_ * param->min_sizes_size; + size = size + UP_ROUND(NNACLGetHeight(output_tensor), COMM_SHAPE_SIZE); + size = size * sizeof(float); + NNACL_CHECK_MALLOC_SIZE(size); + prior_box->output_ = (float *)self->env_->Alloc(self->env_->allocator_, size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(prior_box->output_); + prior_box->output_size_ = 0; + + int ret = PriorBoxInitOutput(prior_box, param, different_aspect_ratios, different_aspect_ratios_size); + self->env_->Free(self->env_->allocator_, different_aspect_ratios); + if (ret != NNACL_OK) { + return ret; + } + + // do clip + if (param->clip) { + for (int i = 0; i < prior_box->output_size_; i++) { + float item = prior_box->output_[i]; + if (item > 1.0f) { + item = 1.0f; + } + if (item < 0.0f) { + item = 0.0f; + } + prior_box->output_[i] = item; + } + } + + // variance + for (int i = 0; i < NNACLGetHeight(output_tensor) / COMM_SHAPE_SIZE; i++) { + for (int j = 0; j < COMM_SHAPE_SIZE; j++) { + prior_box->output_[prior_box->output_size_++] = param->variances[j]; + } + } + return NNACL_OK; +} + +int PriorBoxCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, RunPriorBox, self, self->thread_nr_); +} + +KernelBase *CreatePriorBox(OpParameter *param, int data_type) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)malloc(sizeof(PriorBoxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(prior_box); + memset(prior_box, 0, sizeof(PriorBoxStruct)); + + prior_box->base_.Prepare = DefaultPrepare2In1Out; + prior_box->base_.Resize = PriorBoxResize; + prior_box->base_.Release = PriorBoxRelease; + prior_box->base_.Compute = PriorBoxCompute; + return (KernelBase *)prior_box; +} + +REG_KERNEL_CREATOR(PrimType_PriorBox, kNumberTypeFloat32, CreatePriorBox) +REG_KERNEL_CREATOR(PrimType_PriorBox, kNumberTypeInt8, CreatePriorBox) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.h new file mode 100644 index 00000000..5d728fdd --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/prior_box.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
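
(Aside, not part of the patch: the corner arithmetic of PriorBoxInitOutput above for a single square min_size box; the feature-map size, step, and cell indices are illustrative values.)

#include <stdio.h>

int main(void) {
  float offset = 0.5f, min = 32.f, step_w = 16.f, step_h = 16.f;
  int fmap_w = 19, fmap_h = 19, i = 4, j = 7; /* cell (row, col) on the feature map */
  float cy = i + offset, cx = j + offset;
  /* half extent in cells is min / step * 0.5; corners are normalized by the map size */
  printf("xmin=%.4f ymin=%.4f xmax=%.4f ymax=%.4f\n",
         (cx - min / step_w * 0.5f) / fmap_w, (cy - min / step_h * 0.5f) / fmap_h,
         (cx + min / step_w * 0.5f) / fmap_w, (cy + min / step_h * 0.5f) / fmap_h);
  return 0;
}
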
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_PRIOR_BOX_H_ +#define NNACL_KERNEL_PRIOR_BOX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct PriorBoxStruct { + KernelBase base_; + float *output_; + int output_size_; + int fmap_h_; + int fmap_w_; + float step_h_; + float step_w_; +} PriorBoxStruct; + +KernelBase *CreatePriorBox(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_PRIOR_BOX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.c new file mode 100644 index 00000000..a11d6d90 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.c @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/ragged_range.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/ragged_range_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/ragged_range_fp16.h" +#endif + +int RaggedRangeCompute(KernelBase *self) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(ragged_range); + + TensorC *input0 = self->in_[Index0]; + TensorC *input1 = self->in_[Index1]; + TensorC *input2 = self->in_[Index2]; + TensorC *output0 = self->out_[Index0]; + TensorC *output1 = self->out_[Index1]; + + if (input0->data_type_ == kNumberTypeFloat32) { + RaggedRangeFp32((float *)input0->data_, (float *)input1->data_, (float *)input2->data_, (int *)output0->data_, + (float *)output1->data_, ragged_range); + } else if (input0->data_type_ == kNumberTypeInt32) { + RaggedRangeInt((int *)input0->data_, (int *)input1->data_, (int *)input2->data_, (int *)output0->data_, + (int *)output1->data_, ragged_range); + } else if (input0->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + RaggedRangeFp16((float16_t *)input0->data_, (float16_t *)input1->data_, (float16_t *)input2->data_, + (int *)output0->data_, (float16_t *)output1->data_, ragged_range); +#endif + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int RaggedRangeResize(KernelBase *self) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(ragged_range); + + ragged_range->rows_ = self->out_[OUTPUT_INDEX]->shape_[Index0] - 1; + ragged_range->starts_is_scalar_ = self->in_[FIRST_INPUT]->shape_size_ == 0; + ragged_range->limits_is_scalar_ = self->in_[SECOND_INPUT]->shape_size_ == 0; + ragged_range->deltas_is_scalar_ = self->in_[THIRD_INPUT]->shape_size_ == 0; + return NNACL_OK; +} + +KernelBase *CreateRaggedRange(OpParameter *param, int data_type) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)malloc(sizeof(RaggedRangeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(ragged_range); + ragged_range->base_.Release = DefaultRelease; + ragged_range->base_.Prepare = DefaultPrepare3In2Out; + ragged_range->base_.Resize = RaggedRangeResize; + ragged_range->base_.Compute = RaggedRangeCompute; + return (KernelBase *)ragged_range; +} + +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeInt32, CreateRaggedRange) +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat16, CreateRaggedRange) +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat32, CreateRaggedRange) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.h new file mode 100644 index 00000000..e19ea067 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/ragged_range.h @@ -0,0 +1,35 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
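
(Aside, not part of the patch: a sketch of the output contract that RaggedRangeResize above appears to assume, namely that output0 holds rows_ + 1 row splits and output1 the concatenated ranges. The ceil-division range length and positive deltas are assumptions for illustration.)

#include <stdio.h>

int main(void) {
  int starts[] = {0, 5}, limits[] = {3, 8}, deltas[] = {1, 1};
  int rows = 2, splits[3] = {0}, total = 0;
  for (int i = 0; i < rows; ++i) {
    int count = (limits[i] - starts[i] + deltas[i] - 1) / deltas[i]; /* ceil((limit - start) / delta) */
    total += count;
    splits[i + 1] = total;
  }
  printf("splits: %d %d %d\n", splits[0], splits[1], splits[2]); /* 0 3 6 */
  return 0;
}
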
+ */ + +#ifndef NNACL_KERNEL_RAGGED_RANGE_H_ +#define NNACL_KERNEL_RAGGED_RANGE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct RaggedRangeStruct { + KernelBase base_; + int rows_; + bool starts_is_scalar_; + bool limits_is_scalar_; + bool deltas_is_scalar_; +} RaggedRangeStruct; + +KernelBase *CreateRaggedRange(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RAGGED_RANGE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.c new file mode 100644 index 00000000..812a50c1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.c @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/range.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/range_parameter.h" +#include "nnacl_c/fp32/range_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/range_fp16.h" +#endif + +int RangeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + int output_num = NNACLGetElementNum(output); + + if (self->in_size_ == THREE_TENSOR) { + TensorC *delta = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(delta); + + if (input->data_type_ == kNumberTypeFloat32) { + Range((float *)output->data_, *(float *)input->data_, *(float *)delta->data_, output_num); + } else if (input->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + RangeFp16((float16_t *)output->data_, *(float16_t *)input->data_, *(float16_t *)delta->data_, output_num); +#endif + } else if (input->data_type_ == kNumberTypeInt32) { + RangeInt((int *)output->data_, *(int *)input->data_, *(int *)delta->data_, output_num); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } else { + if (input->data_type_ == kNumberTypeInt32) { + RangeParameter *param = (RangeParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + RangeInt((int *)output->data_, param->start_, param->delta_, output_num); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } + return NNACL_OK; +} + +KernelBase *CreateRange(OpParameter *param, int data_type) { + RangeStruct *range = (RangeStruct *)malloc(sizeof(RangeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(range); + range->base_.Release = DefaultRelease; + range->base_.Prepare = DefaultPrepare1In1Out; + range->base_.Resize = DefaultResize; + range->base_.Compute = RangeCompute; + return (KernelBase *)range; +} + +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeInt32, CreateRange) +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeFloat32, CreateRange) +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeFloat16, CreateRange) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.h 
b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.h new file mode 100644 index 00000000..32ba2498 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/range.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RANGE_H_ +#define NNACL_KERNEL_RANGE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct RangeStruct { + KernelBase base_; +} RangeStruct; + +KernelBase *CreateRange(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RANGE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.c new file mode 100644 index 00000000..ce965034 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.c @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/rank.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +int RankCompute(KernelBase *self) { + size_t rank = self->in_[FIRST_INPUT]->shape_size_; + void *output_data = self->out_[OUTPUT_INDEX]->data_; + if (self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + *(float16_t *)output_data = (float16_t)rank; +#endif + } else { + *(float *)output_data = (float)rank; + } + return NNACL_OK; +} + +KernelBase *CreateRank(OpParameter *param, int data_type) { + RankStruct *rank = (RankStruct *)malloc(sizeof(RankStruct)); + NNACL_CHECK_NULL_RETURN_NULL(rank); + rank->base_.Release = DefaultRelease; + rank->base_.Prepare = DefaultPrepare1In1Out; + rank->base_.Resize = DefaultResize; + rank->base_.Compute = RankCompute; + return (KernelBase *)rank; +} + +REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat32, CreateRank) +REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat16, CreateRank) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.h new file mode 100644 index 00000000..ef2e55e5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/rank.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RANK_H_ +#define NNACL_KERNEL_RANK_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct RankStruct { + KernelBase base_; +} RankStruct; + +KernelBase *CreateRank(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RANK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c new file mode 100644 index 00000000..4752357a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.c @@ -0,0 +1,434 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/reduce.h" +#include +#include "nnacl_c/fp32/reduce_fp32.h" +#include "nnacl_c/kernel/reshape.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +void InitialReduceKernelList(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + ReduceParameter *param = (ReduceParameter *)(base->param_); + + ReduceKernelList func_list[] = {{Reduce_Sum, ReduceSum, IntReduceSum, NULL, ReduceSumByLastAxis}, + {Reduce_Mean, ReduceMean, IntReduceMean, NULL, NULL}, + {Reduce_Max, ReduceMax, IntReduceMax, NULL, ReduceMaxByLastAxis}, + {Reduce_Min, ReduceMin, IntReduceMin, NULL, NULL}, + {Reduce_Prod, ReduceProd, IntReduceProd, NULL, NULL}, + {Reduce_SumSquare, ReduceSum, IntReduceSum, NULL, NULL}, + {Reduce_ASum, ReduceSum, IntReduceSum, NULL, NULL}, + {Reduce_All, NULL, NULL, ReduceAll, NULL}, + {Reduce_L2, ReduceL2Norm, NULL, NULL, NULL}}; + + size_t list_len = sizeof(func_list) / sizeof(ReduceKernelList); + for (size_t i = 0; i < list_len; ++i) { + if (param->mode_ == func_list[i].type_) { + reduce->compute_ = func_list[i]; + return; + } + } +} + +int CallReduceUnit(KernelBase *base, int task_id) { + ReduceStruct *reduce = (ReduceStruct *)base; + NNACL_CHECK_NULL_RETURN_ERR(reduce->src_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce->dst_data_); + + if (reduce->data_type_ == kNumberTypeFloat32) { + if (reduce->inner_size_ == 1 && reduce->compute_.float_last_axis_func_ != NULL) { + return reduce->compute_.float_last_axis_func_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (float *)(reduce->src_data_), (float *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } else { + NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.float_function_); + return reduce->compute_.float_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (float 
*)(reduce->src_data_), (float *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } + } + + if (reduce->data_type_ == kNumberTypeBool) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.bool_function_); + return reduce->compute_.bool_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (bool *)(reduce->src_data_), (bool *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } + + if (reduce->data_type_ == kNumberTypeInt32) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.int_function_); + return reduce->compute_.int_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (int *)(reduce->src_data_), (int *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } + + return NNACL_REDUCE_UNSUPPORTED_DATA_TYPE; +} + +int ReduceImpl(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + ReduceStruct *reduce = (ReduceStruct *)cdata; + return reduce->call_uint_((KernelBase *)reduce, task_id); +} + +int CopyReduceyInputToOutput(ReduceStruct *reduce) { + int total_num = NNACLGetElementNum(reduce->base_.in_[FIRST_INPUT]); + NNACL_CHECK_FALSE(total_num == 0, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + int block_num = UP_DIV(total_num, reduce->base_.thread_nr_); + int tmp_thread_num = UP_DIV(total_num, block_num); + NNACL_CHECK_FALSE(tmp_thread_num == 0, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + ReshapeStruct reshape_struct; + reshape_struct.base_.in_ = reduce->base_.in_; + reshape_struct.base_.out_ = reduce->base_.out_; + reshape_struct.block_num_ = block_num; + reshape_struct.total_num_ = total_num; + reshape_struct.base_.thread_nr_ = tmp_thread_num; + return reduce->base_.env_->ParallelLaunch(reduce->base_.env_->thread_pool_, ParallelReshape, &reshape_struct, + tmp_thread_num); +} + +int MallocReduceTmpBuffer(ReduceStruct *reduce) { + // Clean pointers in data_buffer for free condition checking in FreeReduceTmpBuffer. 
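+  // Reduction over several axes is staged: each pass except the last writes
+  // into one of these intermediate buffers, and the final pass writes directly
+  // into the output tensor (see ReduceCompute), so ReduceCalculateTmpBufferSize
+  // reserves only num_axes_ - 1 buffers.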
+ memset(reduce->data_buffers_, 0, reduce->data_buffers_size_ * sizeof(void *)); + + for (int i = 0; i < reduce->data_buffers_size_; i++) { + reduce->data_buffers_[i] = reduce->base_.env_->Alloc( + reduce->base_.env_->allocator_, reduce->data_buffer_sizes_[i] * DataTypeCSize(reduce->data_type_)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reduce->data_buffers_[i]); + } + return NNACL_OK; +} + +void FreeReduceTmpBuffer(ReduceStruct *reduce) { + for (int i = 0; i < reduce->data_buffers_size_; i++) { + if (reduce->data_buffers_[i] != NULL) { + reduce->base_.env_->Free(reduce->base_.env_->allocator_, reduce->data_buffers_[i]); + } + reduce->data_buffers_[i] = NULL; + } +} + +int CalculateReduceCoeffOutput(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + + if (reduce->data_type_ != kNumberTypeFloat32) { + return NNACL_REDUCE_COEFF_DATA_TYPE_INVALID; + } + TensorC *out_tensor = reduce->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + int num = NNACLGetElementNum(out_tensor); + + float *out_data = (float *)out_tensor->data_; + for (int i = 0; i < num; ++i) { + out_data[i] *= ((ReduceParameter *)reduce->base_.param_)->coeff; + } + return NNACL_OK; +} + +void HandleReduceASumAndSumSquare(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + if (reduce->data_type_ == kNumberTypeInt32 || reduce->data_type_ == kNumberTypeBool) { + return; + } + + TensorC *in_tensor = base->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(in_tensor); + float *data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(data); + + int num = NNACLGetElementNum(in_tensor); + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_ASum) { + for (int i = 0; i < num; ++i) { + if (data[i] < 0.0f) { + data[i] = 0.0f - data[i]; + } + } + } + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_SumSquare) { + for (int i = 0; i < num; ++i) { + data[i] = data[i] * data[i]; + } + return; + } +} + +int ReduceCheckInputsOutputs(ReduceStruct *reduce) { + NNACL_CHECK_FALSE(reduce->base_.in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(reduce->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + for (size_t i = 0; i < reduce->base_.in_size_; i++) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->base_.in_[i]); + } + for (size_t i = 0; i < reduce->base_.out_size_; i++) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->base_.out_[i]); + } + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + if (reduce->base_.in_size_ > ONE_TENSOR) { + TensorC *axes_tensor = reduce->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(axes_tensor); + NNACL_CHECK_FALSE(axes_tensor->data_type_ != kNumberTypeInt && axes_tensor->data_type_ != kNumberTypeInt32 && + axes_tensor->data_type_ != kNumberTypeInt64, + NNACL_REDUCE_AXES_TENSOR_ERROR); + } + return NNACL_OK; +} + +int ReduceCommonPrepare(ReduceStruct *reduce) { + int ret = ReduceCheckInputsOutputs(reduce); + if (ret != NNACL_OK) { + return ret; + } + + if (reduce->base_.in_size_ == ONE_TENSOR) { + reduce->num_axes_ = 0; + return NNACL_OK; + } + + TensorC *axes_tensor = reduce->base_.in_[SECOND_INPUT]; + reduce->num_axes_ = NNACLGetElementNum(axes_tensor); + if (axes_tensor->data_ != NULL && (reduce->num_axes_ <= 0 || reduce->num_axes_ > MAX_SHAPE_SIZE)) { + return NNACL_REDUCE_AXES_TENSOR_ERROR; + } + if (axes_tensor->data_ == NULL) { + reduce->num_axes_ = reduce->base_.in_[FIRST_INPUT]->shape_size_; + for (int i = 0; i < 
reduce->num_axes_; i++) { + reduce->axes_[i] = i; + } + } else { + if (axes_tensor->data_type_ == kNumberTypeInt32 || axes_tensor->data_type_ == kNumberTypeInt) { + NNACL_CHECK_FALSE(NNACLGetSize(axes_tensor) == 0, NNACL_REDUCE_AXES_TENSOR_ERROR); + (void)memcpy(reduce->axes_, axes_tensor->data_, NNACLGetSize(axes_tensor)); + } else { + int64_t *axes_data = axes_tensor->data_; + for (size_t i = 0; i < reduce->num_axes_; i++) { + reduce->axes_[i] = (int32_t)axes_data[i]; + } + } + } + + return NNACL_OK; +} + +int CheckReduceParameters(ReduceStruct *reduce) { + int input_shape_size = reduce->base_.in_[FIRST_INPUT]->shape_size_; + NNACL_CHECK_FALSE(reduce->num_axes_ > input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + for (int i = 0; i < reduce->num_axes_; i++) { + NNACL_CHECK_FALSE(reduce->axes_[i] < -input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + NNACL_CHECK_FALSE(reduce->axes_[i] >= input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + if (reduce->axes_[i] < 0) { + reduce->axes_[i] += input_shape_size; + } + } + + if (((ReduceParameter *)reduce->base_.param_)->reduce_to_end_) { + // actual num of axes to reduce + reduce->num_axes_ = (int)(input_shape_size)-reduce->axes_[0]; + for (int i = 1; i < reduce->num_axes_; ++i) { + reduce->axes_[i] = reduce->axes_[0] + i; + } + } + + if (reduce->num_axes_ == 0) { + for (int i = 0; i < input_shape_size; i++) { + reduce->axes_[i] = i; + } + reduce->num_axes_ = input_shape_size; + } + return NNACL_OK; +} + +void ReduceCalculateInnerOuterSize(ReduceStruct *reduce) { + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + int tmp_input_shape[MAX_SHAPE_SIZE]; + memcpy(tmp_input_shape, input_tensor->shape_, MAX_SHAPE_SIZE * sizeof(int)); + reduce->offset_size_ = 0; + + for (int i = 0; i < reduce->num_axes_; ++i) { + int axis = reduce->axes_[i]; + int outer_size = 1; + for (int j = 0; j < axis; j++) { + outer_size *= tmp_input_shape[j]; + } + reduce->outer_sizes_[reduce->offset_size_] = outer_size; + + int inner_size = 1; + for (int k = axis + 1; k < input_tensor->shape_size_; k++) { + inner_size *= tmp_input_shape[k]; + } + reduce->inner_sizes_[reduce->offset_size_] = inner_size; + reduce->axis_sizes_[reduce->offset_size_] = tmp_input_shape[axis]; + + reduce->offset_size_++; + tmp_input_shape[axis] = 1; + } +} + +void ReduceCalculateTmpBufferSize(ReduceStruct *reduce) { + reduce->data_buffers_size_ = 0; + + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + int tmp_input_shape[MAX_SHAPE_SIZE]; + memcpy(tmp_input_shape, input_tensor->shape_, MAX_SHAPE_SIZE * sizeof(int)); + // calculate size of buffer to malloc for each reducing axis + for (int i = 0; i < reduce->num_axes_ - 1; i++) { + int axis = reduce->axes_[i]; + size_t size = 1; + for (size_t j = 0; j < input_tensor->shape_size_; j++) { + if (axis != (int)(j)) { + size *= (size_t)(tmp_input_shape[j]); + } + } + reduce->data_buffer_sizes_[reduce->data_buffers_size_++] = size; + tmp_input_shape[axis] = 1; + } +} + +void ReduceDecideIfOnlyCopy(ReduceStruct *reduce) { + ReduceModeC can_not_copy[] = {Reduce_SumSquare, Reduce_ASum, Reduce_All, Reduce_L2}; + for (int i = 0; i < sizeof(can_not_copy) / sizeof(ReduceModeC); i++) { + if (can_not_copy[i] == ((ReduceParameter *)reduce->base_.param_)->mode_) { + reduce->only_copy_ = false; + return; + } + } + + int *in_shape = reduce->base_.in_[FIRST_INPUT]->shape_; + + for (int i = 0; i < reduce->num_axes_; i++) { + int axis = reduce->axes_[i]; + if (in_shape[axis] != 1) { + reduce->only_copy_ = false; + return; + } + } + 
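// Every reduced axis has size 1, so the reduction degenerates to a plain
+  // copy of the input to the output (see CopyReduceyInputToOutput).
+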
reduce->only_copy_ = true; + return; +} + +int ReducePrepare(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReduceStruct *reduce = (ReduceStruct *)self; + + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, ONE_TENSOR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, ONE_TENSOR); + + int ret = ReduceCommonPrepare(reduce); + if (ret != NNACL_OK) { + return ret; + } + + reduce->init_kernel_list_(self); + return NNACL_OK; +} + +int ReduceResize(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReduceStruct *reduce = (ReduceStruct *)self; + + int ret = CheckReduceParameters(reduce); + if (ret != NNACL_OK) { + return ret; + } + + ReduceDecideIfOnlyCopy(reduce); + ReduceCalculateTmpBufferSize(reduce); + ReduceCalculateInnerOuterSize(reduce); + + if (reduce->num_axes_ == 1) { + self->thread_nr_ = self->UpdateThread( + TC_TYPE(PrimType_ReduceFusion, ((ReduceParameter *)reduce->base_.param_)->mode_), + reduce->inner_sizes_[Index0] * reduce->axis_sizes_[Index0], + reduce->inner_sizes_[Index0] * reduce->axis_sizes_[Index0], reduce->outer_sizes_[Index0], self->thread_nr_); + } else { + self->thread_nr_ = self->UpdateThread(TC_TYPE(PrimType_ReduceFusion, Reduce_Max + 1), 0, 0, + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + } + return NNACL_OK; +} + +int ReduceCompute(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReduceStruct *reduce = (ReduceStruct *)self; + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != reduce->data_type_, NNACL_ERR); + + if (reduce->only_copy_) { + return CopyReduceyInputToOutput(reduce); + } + + int ret = MallocReduceTmpBuffer(reduce); + if (ret != NNACL_OK) { + FreeReduceTmpBuffer(reduce); + return ret; + } + + reduce->src_data_ = self->in_[FIRST_INPUT]->data_; + reduce->handle_sum_square_(self); + for (int i = 0; i < reduce->num_axes_; i++) { + if (i != (reduce->num_axes_ - 1)) { + reduce->dst_data_ = reduce->data_buffers_[i]; + } else { + reduce->dst_data_ = self->out_[FIRST_INPUT]->data_; + } + reduce->outer_size_ = reduce->outer_sizes_[i]; + reduce->inner_size_ = reduce->inner_sizes_[i]; + reduce->axis_size_ = reduce->axis_sizes_[i]; + NNACL_CHECK_FALSE(reduce->axis_size_ == 0, NNACL_REDUCE_AXIS_SIZE_ERROR); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ReduceImpl, self, self->thread_nr_); + if (ret != NNACL_OK) { + FreeReduceTmpBuffer(reduce); + return ret; + } + reduce->src_data_ = reduce->dst_data_; + } + + ReduceParameter *param = (ReduceParameter *)reduce->base_.param_; + if (param->reduce_to_end_ && fabsf(param->coeff) > 1e-5) { + ret = reduce->calculate_coeff_(self); + } + + FreeReduceTmpBuffer(reduce); + return ret; +} + +KernelBase *CreateReduce(OpParameter *param, int data_type) { + ReduceStruct *reduce = (ReduceStruct *)malloc(sizeof(ReduceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reduce); + memset(reduce, 0, sizeof(ReduceStruct)); + reduce->data_type_ = data_type; + reduce->base_.Release = DefaultRelease; + reduce->base_.Prepare = ReducePrepare; + reduce->base_.Resize = ReduceResize; + reduce->base_.Compute = ReduceCompute; + reduce->handle_sum_square_ = HandleReduceASumAndSumSquare; + reduce->calculate_coeff_ = CalculateReduceCoeffOutput; + reduce->init_kernel_list_ = InitialReduceKernelList; + reduce->call_uint_ = CallReduceUnit; + return (KernelBase *)reduce; +} + +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeBool, CreateReduce) +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeInt32, CreateReduce) +REG_KERNEL_CREATOR(PrimType_ReduceFusion, 
kNumberTypeFloat32, CreateReduce)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.h
new file mode 100644
index 00000000..18f75649
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reduce.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_KERNEL_REDUCE_H_
+#define NNACL_KERNEL_REDUCE_H_
+
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/tensor_c.h"
+#include "nnacl_c/kernel.h"
+
+typedef struct ReduceKernelList {
+  int type_;
+  int (*float_function_)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
+                         float *dst_data, const int tid, const int thread_num);
+  int (*int_function_)(const int outer_size, const int inner_size, const int axis_size, const int *src_data,
+                       int *dst_data, const int tid, const int thread_num);
+  int (*bool_function_)(const int outer_size, const int inner_size, const int axis_size, const bool *src_data,
+                        bool *dst_data, const int tid, const int thread_num);
+  int (*float_last_axis_func_)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
+                               float *dst_data, const int tid, const int thread_num);
+} ReduceKernelList;
+
+typedef struct ReduceStruct {
+  KernelBase base_;
+  bool only_copy_;
+  int num_axes_;
+  TypeIdC data_type_;
+  int axes_[MAX_SHAPE_SIZE];
+
+  void *data_buffers_[MAX_SHAPE_SIZE];
+  size_t data_buffer_sizes_[MAX_SHAPE_SIZE];
+  int data_buffers_size_;
+  ReduceModeC mode_;
+
+  int outer_sizes_[MAX_SHAPE_SIZE];
+  int inner_sizes_[MAX_SHAPE_SIZE];
+  int axis_sizes_[MAX_SHAPE_SIZE];
+  int offset_size_;
+
+  int outer_size_;
+  int inner_size_;
+  int axis_size_;
+
+  void *src_data_;
+  void *dst_data_;
+  ReduceKernelList compute_;
+
+  void (*handle_sum_square_)(KernelBase *base);
+  void (*init_kernel_list_)(KernelBase *base);
+  int (*calculate_coeff_)(KernelBase *base);
+  int (*call_uint_)(KernelBase *base, int task_id);
+} ReduceStruct;
+
+KernelBase *CreateReduce(OpParameter *param, int data_type);
+int ReducePrepare(KernelBase *self);
+int ReduceResize(KernelBase *self);
+int ReduceCompute(KernelBase *self);
+
+#endif  // NNACL_KERNEL_REDUCE_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.c
new file mode 100644
index 00000000..bc6271e9
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.c
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/reshape.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" + +int kMinCostPerThread = 16384; + +int ParallelReshape(void *param, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(param); + ReshapeStruct *reshape = (ReshapeStruct *)param; + + int data_size = (int)DataTypeCSize(reshape->base_.in_[0]->data_type_); + uint8_t *in_start = (uint8_t *)(reshape->base_.in_[0]->data_) + task_id * reshape->block_num_ * data_size; + uint8_t *out_start = (uint8_t *)(reshape->base_.out_[0]->data_) + task_id * reshape->block_num_ * data_size; + int copy_num = reshape->block_num_; + if (task_id == (reshape->base_.thread_nr_ - 1)) { + copy_num = reshape->total_num_ - task_id * reshape->block_num_; + } + (void)memcpy(out_start, in_start, copy_num * data_size); + return NNACL_OK; +} + +int ReshapeResize(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReshapeStruct *reshape = (ReshapeStruct *)self; + reshape->total_num_ = NNACLGetElementNum(self->in_[0]); + if (reshape->total_num_ == 0) { + return NNACL_OK; + } + + self->thread_nr_ = MSMIN(self->thread_nr_, UP_DIV(reshape->total_num_, kMinCostPerThread)); + if (self->thread_nr_ < 1) { + self->thread_nr_ = 1; + } + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + reshape->block_num_ = UP_DIV(reshape->total_num_, self->thread_nr_); + NNACL_CHECK_ZERO_RETURN_ERR(reshape->block_num_); + self->thread_nr_ = UP_DIV(reshape->total_num_, reshape->block_num_); + + return NNACL_OK; +} + +int ReshapeCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ParallelReshape, self, self->thread_nr_); +} + +KernelBase *CreateReshape(OpParameter *param, int data_type) { + ReshapeStruct *reshape = (ReshapeStruct *)malloc(sizeof(ReshapeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reshape); + reshape->base_.Release = DefaultRelease; + reshape->base_.Prepare = DefaultPrepare1In1Out; + reshape->base_.Resize = ReshapeResize; + reshape->base_.Compute = ReshapeCompute; + return (KernelBase *)reshape; +} + +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, 
kNumberTypeInt8, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeUInt8, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt64, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeBool, CreateReshape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.h new file mode 100644 index 00000000..dca87f68 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reshape.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RESHAPE_H_ +#define NNACL_KERNEL_RESHAPE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ReshapeStruct { + KernelBase base_; + int block_num_; + int total_num_; +} ReshapeStruct; + +int ParallelReshape(void *param, int task_id, float l, float r); + +KernelBase *CreateReshape(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RESHAPE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.c new file mode 100644 index 00000000..ade5ccf5 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.c @@ -0,0 +1,166 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/reverse.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/reverse_parameter.h"
+#include "nnacl_c/fp32/reverse_fp32.h"
+
+int ReverseStride(TensorC *input, int index) {
+  int stride = 1;
+  for (int i = index + 1; i < (int)input->shape_size_; i++) {
+    stride *= input->shape_[i];
+  }
+  return stride;
+}
+
+int ReverseRun(void *cdata, int task_id, float l, float r) {
+  ReverseStruct *reverse = (ReverseStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+
+  int offset = task_id * reverse->thread_stride_;
+  int count = NNACL_MIN(reverse->thread_stride_, reverse->data_size_ - offset);
+  if (count <= 0) {
+    return NNACL_OK;
+  }
+
+  float *in_ptr = (float *)reverse->base_.in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(in_ptr);
+  float *out_ptr = (float *)reverse->base_.out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(out_ptr);
+  // The last task may cover fewer than thread_stride_ elements, so pass count
+  // to stay within the tensor bounds.
+  return Reverse(in_ptr + offset, out_ptr, count, reverse->tmp_ + offset);
+}
+
+int ReverseUpdateAxisInfo(ReverseStruct *reverse) {
+  ReverseParameter *reverse_param = (ReverseParameter *)reverse->base_.param_;
+  int in_shape_len = reverse->base_.in_[FIRST_INPUT]->shape_size_;
+  for (int i = 0; i < reverse_param->num_axis_; ++i) {
+    if (reverse_param->axis_[i] < 0) {
+      reverse_param->axis_[i] += in_shape_len;
+    }
+    if (reverse_param->axis_[i] < 0 || reverse_param->axis_[i] >= in_shape_len) {
+      return NNACL_REVERSE_AXIS_VALUE_INVALID;
+    }
+  }
+  return NNACL_OK;
+}
+
+int ReverseCompute(KernelBase *self) {
+  return self->env_->ParallelLaunch(self->env_->thread_pool_, ReverseRun, self, self->thread_nr_);
+}
+
+int ReversePrepare(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
+  ReverseStruct *reverse = (ReverseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+  if (((ReverseParameter *)self->param_)->num_axis_ < Num1) {
+    return NNACL_REVERSE_AXIS_INVALID;
+  }
+  return NNACL_OK;
+}
+
+int ReverseRelease(KernelBase *self) {
+  ReverseStruct *reverse = (ReverseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+  if (reverse->tmp_ != NULL) {
+    self->env_->Free(self->env_->allocator_, reverse->tmp_);
+    reverse->tmp_ = NULL;
+  }
+  return NNACL_OK;
+}
+
+int ReverseResize(KernelBase *self) {
+  ReverseStruct *reverse = (ReverseStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(reverse);
+
+  TensorC *input = self->in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+  TensorC *output = self->out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+
+  // trans negative to positive axis
+  int ret = ReverseUpdateAxisInfo(reverse);
+  if (ret != NNACL_OK) {
+    return ret;
+  }
+
+  reverse->data_size_ = NNACLGetElementNum(input);
+  if (NNACLGetElementNum(output) != reverse->data_size_) {
+    return NNACL_REVERSE_DATA_SIZE_INVALID;
+  }
+
+  self->thread_nr_ = NNACL_MIN(self->thread_nr_, reverse->data_size_);
+  NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_);
+  reverse->thread_stride_ = UP_DIV(reverse->data_size_, self->thread_nr_);
+
+  ReverseParameter *reverse_param = (ReverseParameter *)self->param_;
+  if (reverse_param->num_axis_ > input->shape_size_) {
+    return NNACL_REVERSE_NUM_AXIS_INVALID;
+  }
+  if (input->shape_size_ > REVERSE_SHAPE_MAX_SIZE) {
+    return NNACL_REVERSE_NUM_AXIS_INVALID;
+  }
+
+  (void)self->Release(self);
+  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(reverse->data_size_, sizeof(int), NNACL_ERR);
+  reverse->tmp_ = (int *)self->env_->Alloc(self->env_->allocator_, reverse->data_size_ * sizeof(int));
+
NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reverse->tmp_); + memset(reverse->tmp_, 0, reverse->data_size_ * sizeof(int)); + + for (int i = 0; i < reverse_param->num_axis_; i++) { + int axis = reverse_param->axis_[i]; + int stride = ReverseStride(input, axis); + reverse->strides_[i] = stride; + reverse->in_count_[i] = input->shape_[axis]; + reverse->out_count_[i] = 1; + for (int j = 0; j < axis; j++) { + reverse->out_count_[i] *= input->shape_[j]; + } + } + + int out; + int in; + int C; + int m; + for (int i = 0; i < reverse->data_size_; ++i) { + int tmp = i; + for (int j = 0; j < reverse_param->num_axis_; ++j) { + C = reverse->in_count_[j]; + out = tmp / (C * reverse->strides_[j]); + in = tmp / reverse->strides_[j] - out * C; + m = tmp % reverse->strides_[j]; + tmp = out * C * reverse->strides_[j] + reverse->strides_[j] * (C - 1 - in) + m; + } + reverse->tmp_[i] = tmp; + } + + return NNACL_OK; +} + +KernelBase *CreateReverse(OpParameter *param, int data_type) { + ReverseStruct *reverse = (ReverseStruct *)malloc(sizeof(ReverseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reverse); + memset(reverse, 0, sizeof(ReverseStruct)); + reverse->base_.Release = ReverseRelease; + reverse->base_.Prepare = ReversePrepare; + reverse->base_.Resize = ReverseResize; + reverse->base_.Compute = ReverseCompute; + return (KernelBase *)reverse; +} + +REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeFloat32, CreateReverse) +REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeInt32, CreateReverse) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.h new file mode 100644 index 00000000..b69760d6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/reverse.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_REVERSE_H_ +#define NNACL_KERNEL_REVERSE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct { + KernelBase base_; + int thread_stride_; + int data_size_; + int *tmp_; + int strides_[COMM_SHAPE_SIZE]; + int in_count_[COMM_SHAPE_SIZE]; + int out_count_[COMM_SHAPE_SIZE]; +} ReverseStruct; + +KernelBase *CreateReverse(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_REVERSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.c new file mode 100644 index 00000000..d87f6dd9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.c @@ -0,0 +1,333 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/kernel/scale.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/scale_parameter.h"
+#include "nnacl_c/fp32/scale_fp32.h"
+#include "nnacl_c/tensor_c_utils.h"
+#ifdef ENABLE_FP16
+#include "nnacl_c/fp16/utils_fp16.h"
+#include "nnacl_c/fp16/scale_fp16.h"
+#endif
+
+int ScaleRunF16(ScaleStruct *scale, int task_id, ActType act_type) {
+#ifdef ENABLE_FP16
+  switch (act_type) {
+    case ActType_Relu6:
+      DoScaleRelu6Fp16((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_,
+                       (const float16_t *)scale->offset_, task_id, scale);
+      break;
+    case ActType_Relu:
+      Fp16DoScaleRelu((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_,
+                      (const float16_t *)scale->offset_, task_id, scale);
+      break;
+    case ActType_No:
+      DoScaleFp16((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_,
+                  (const float16_t *)scale->offset_, task_id, scale);
+      break;
+    default:
+      return NNACL_ERR;
+  }
+  return NNACL_OK;
+#endif
+  return NNACL_DISABLE_FP16;
+}
+
+int ScaleInitInputDataType(ScaleStruct *scale) {
+  if (scale->data_type_ == kNumberTypeFloat32) {
+    return NNACL_OK;
+  }
+
+#ifdef ENABLE_FP16
+  TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(scale_tensor);
+  if (scale_tensor->data_type_ != kNumberTypeFloat16 && scale->malloc_scale_ == false) {
+    scale->malloc_scale_ = true;
+    scale->scale_ = GetOrAllocFp16Data(scale_tensor, scale->base_.env_, true);
+  } else {
+    scale->malloc_scale_ = false;
+    scale->scale_ = NULL;
+  }
+
+  if (scale->base_.in_size_ == TWO_TENSOR) {
+    /* already done in prepare */
+    return NNACL_OK;
+  }
+
+  /* offset is the optional third input */
+  TensorC *offset_tensor = scale->base_.in_[THIRD_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(offset_tensor);
+  if (offset_tensor->data_type_ != kNumberTypeFloat16 && scale->malloc_offset_ == false) {
+    scale->malloc_offset_ = true;
+    scale->offset_ = GetOrAllocFp16Data(offset_tensor, scale->base_.env_, true);
+  } else {
+    scale->malloc_offset_ = false;
+    scale->offset_ = NULL;
+  }
+  return NNACL_OK;
+#endif
+  return NNACL_DISABLE_FP16;
+}
+
+int ScaleRunF32(ScaleStruct *scale, int task_id, ActType act_type) {
+  switch (act_type) {
+    case ActType_Relu6:
+      DoScaleRelu6((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_,
+                   (const float *)scale->offset_, task_id, scale);
+      break;
+    case ActType_Relu:
+      DoScaleRelu((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_,
+                  (const float *)scale->offset_, task_id, scale);
+      break;
+    case ActType_No:
+      DoScale((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_,
+              (const float *)scale->offset_, task_id, scale);
+      break;
+    default:
+      return NNACL_SCALE_UNSUPPORT_ACT_TYPE;
+  }
+  return NNACL_OK;
+}
+
+int ScaleRun(void *cdata, int task_id, float l, float r) {
+  ScaleStruct *scale = (ScaleStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(scale);
+  ActType act_type = ((ScaleParameter *)scale->base_.param_)->activation_type_;
+  if (scale->data_type_ == kNumberTypeFloat16) {
+    return ScaleRunF16(scale, task_id, act_type);
+  } else if (scale->data_type_ == kNumberTypeFloat32) {
+    return ScaleRunF32(scale, task_id, act_type);
+  }
+  return NNACL_UNSUPPORTED_DATA_TYPE;
+}
+
+int ScaleCalculateParameter(ScaleStruct *scale) {
+  TensorC *input_tensor = scale->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+  TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(scale_tensor);
+  TensorC *output_tensor = scale->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(output_tensor);
+
+  scale->outer_size_ = 1;
+  scale->axis_size_ = 1;
+  scale->inner_size_ = 1;
+  for (int i = 0; i < scale->axis_; i++) {
+    scale->outer_size_ *= input_tensor->shape_[i];
+  }
+  for (size_t i = 0; i < scale_tensor->shape_size_; i++) {
+    scale->axis_size_ *= input_tensor->shape_[i + scale->axis_];
+  }
+  for (size_t i = scale->axis_ + scale_tensor->shape_size_; i < input_tensor->shape_size_; i++) {
+    scale->inner_size_ *= input_tensor->shape_[i];
+  }
+
+  scale->base_.thread_nr_ = MSMIN(scale->base_.thread_nr_, scale->outer_size_);
+  NNACL_CHECK_ZERO_RETURN_ERR(scale->base_.thread_nr_);
+
+  return NNACL_OK;
+}
+
+int ScaleInitScaleOffset(ScaleStruct *scale) {
+  TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(scale_tensor);
+  int data_type_size = DataTypeCSize(scale->data_type_);
+
+  if (scale->base_.in_size_ == TWO_TENSOR) {
+    scale->malloc_offset_ = true;
+    int malloc_size = NNACLGetElementNum(scale_tensor) * data_type_size;
+    NNACL_CHECK_MALLOC_SIZE(malloc_size);
+    scale->offset_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->offset_);
+    memset(scale->offset_, 0, malloc_size);
+  }
+
+  if (scale->data_type_ == kNumberTypeFloat16) {
+    /* handle fp16 scale and offset in compute */
+    return NNACL_OK;
+  }
+
+  if (scale_tensor->data_ != NULL) {
+    scale->malloc_scale_ = true;
+    int malloc_size = NNACLGetElementNum(scale_tensor) * data_type_size;
+    NNACL_CHECK_MALLOC_SIZE(malloc_size);
+    scale->scale_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->scale_);
+    (void)memcpy(scale->scale_, scale_tensor->data_, malloc_size);
+  } else {
+    scale->malloc_scale_ = false;
+    scale->scale_ = NULL;
+  }
+
+  if (scale->base_.in_size_ == TWO_TENSOR) {
+    return NNACL_OK;
+  }
+  NNACL_CHECK_FALSE(scale->base_.in_size_ != THREE_TENSOR, NNACL_SCALE_INPUT_NUM_INVALID);
+
+  TensorC *offset_tensor = scale->base_.in_[THIRD_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(offset_tensor);
+  if (offset_tensor->data_ != NULL) {
+    scale->malloc_offset_ = true;
+    int malloc_size = NNACLGetElementNum(offset_tensor) * data_type_size;
+    NNACL_CHECK_MALLOC_SIZE(malloc_size);
+    scale->offset_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size);
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->offset_);
+    (void)memcpy(scale->offset_, offset_tensor->data_, malloc_size);
+  } else {
+    scale->malloc_offset_ = false;
+    scale->offset_ = NULL;
+  }
+
+  return NNACL_OK;
+}
+
+int ScaleCheckInputsOutputs(KernelBase *self) {
+  NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR);
+  NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);
+
+  for (size_t i = 0; i < self->in_size_; i++) {
+    TensorC *input_tensor = self->in_[i];
+    NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
+    if (input_tensor->data_type_ != kNumberTypeFloat32 && input_tensor->data_type_ != kNumberTypeFloat16) {
+      return
NNACL_UNSUPPORTED_DATA_TYPE; + } + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + if (output_tensor->data_type_ != kNumberTypeFloat32 && output_tensor->data_type_ != kNumberTypeFloat16) { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int ScaleRelease(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + if (scale->malloc_scale_ && scale->scale_ != NULL) { + self->env_->Free(self->env_->allocator_, scale->scale_); + scale->scale_ = NULL; + scale->malloc_scale_ = false; + } + + if (scale->malloc_offset_ && scale->offset_ != NULL) { + self->env_->Free(self->env_->allocator_, scale->offset_); + scale->offset_ = NULL; + scale->malloc_offset_ = false; + } + return NNACL_OK; +} + +int ScaleResize(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *scale_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + + int origin_axis = ((ScaleParameter *)self->param_)->axis_; + scale->axis_ = origin_axis < 0 ? origin_axis + input_tensor->shape_size_ : origin_axis; + + for (size_t i = 0; i < scale_tensor->shape_size_; i++) { + if (i + scale->axis_ >= input_tensor->shape_size_) { + return NNACL_SCALE_AXIS_AND_SHAPE_UNMATCH; + } + if (input_tensor->shape_[i + scale->axis_] != scale_tensor->shape_[i]) { + return NNACL_SCALE_SCALE_SHAPE_UNMATCH; + } + } + + int ret = ScaleCalculateParameter(scale); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int ScaleCompute(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + scale->input_ = input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->input_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + scale->output_ = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->output_); + + int ret = ScaleInitInputDataType(scale); + if (ret != NNACL_OK) { + return ret; + } + + if (!scale->malloc_scale_) { + TensorC *scale_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + scale->scale_ = scale_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->scale_); + } + + if (!scale->malloc_offset_) { + TensorC *offset_tensor = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(offset_tensor); + scale->offset_ = offset_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->offset_); + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ScaleRun, self, self->thread_nr_); +} + +int ScalePrepare(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + int ret = ScaleCheckInputsOutputs(self); + if (ret != NNACL_OK) { + return ret; + } + + ret = ScaleInitScaleOffset(scale); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +KernelBase *CreateScale(OpParameter *param, int data_type) { + ScaleStruct *scale = (ScaleStruct *)malloc(sizeof(ScaleStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(scale); + memset(scale, 0, sizeof(ScaleStruct)); + scale->data_type_ = data_type; + scale->scale_ = NULL; + scale->offset_ = NULL; + scale->malloc_scale_ = false; + scale->malloc_offset_ = false; + scale->base_.Prepare = ScalePrepare; + 
scale->base_.Resize = ScaleResize; + scale->base_.Compute = ScaleCompute; + scale->base_.Release = ScaleRelease; + return (KernelBase *)scale; +} + +REG_KERNEL_CREATOR(PrimType_ScaleFusion, kNumberTypeFloat16, CreateScale) +REG_KERNEL_CREATOR(PrimType_ScaleFusion, kNumberTypeFloat32, CreateScale) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.h new file mode 100644 index 00000000..87849421 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/scale.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SCALE_H_ +#define NNACL_KERNEL_SCALE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ScaleStruct { + KernelBase base_; + int axis_; + int data_type_; + int axis_size_; + int outer_size_; + int inner_size_; + bool malloc_scale_; + bool malloc_offset_; + void *scale_; + void *offset_; + void *input_; + void *output_; +} ScaleStruct; + +KernelBase *CreateScale(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SCALE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.c new file mode 100644 index 00000000..e9637fdf --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.c @@ -0,0 +1,51 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/shape.h" +#include "nnacl_c/kernel/default_kernel_base.h" + +int ShapeCompute(struct KernelBase *self) { + ShapeStruct *shape = (ShapeStruct *)self; + memcpy(self->out_[OUTPUT_INDEX]->data_, self->in_[FIRST_INPUT]->shape_, shape->shape_size_); + return NNACL_OK; +} + +int ShapeResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + ShapeStruct *shape = (ShapeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(shape); + shape->shape_size_ = self->in_[FIRST_INPUT]->shape_size_ * sizeof(int); + return NNACL_OK; +} + +KernelBase *CreateShape(OpParameter *param, int data_type) { + ShapeStruct *shape = (ShapeStruct *)malloc(sizeof(ShapeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(shape); + shape->base_.Release = DefaultRelease; + shape->base_.Prepare = DefaultPrepare1In1Out; + shape->base_.Resize = ShapeResize; + shape->base_.Compute = ShapeCompute; + return (KernelBase *)shape; +} + +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt32, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeBool, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat16, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat32, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt8, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeUInt8, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt64, CreateShape) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.h new file mode 100644 index 00000000..3cbc9aa2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/shape.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SHAPE_H_ +#define NNACL_KERNEL_SHAPE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ShapeStruct { + KernelBase base_; + int shape_size_; +} ShapeStruct; + +KernelBase *CreateShape(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SHAPE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.c new file mode 100644 index 00000000..ae1768c6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.c @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/size.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" + +int SizeCompute(KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + int *out_data = (int *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + out_data[Index0] = NNACLGetElementNum(in_tensor); + return NNACL_OK; +} + +KernelBase *CreateSize(OpParameter *param, int data_type) { + SizeStruct *size = (SizeStruct *)malloc(sizeof(SizeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(size); + size->base_.Release = DefaultRelease; + size->base_.Prepare = DefaultPrepare1In1Out; + size->base_.Resize = DefaultResize; + size->base_.Compute = SizeCompute; + return (KernelBase *)size; +} + +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeInt32, CreateSize) +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeFloat32, CreateSize) +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeFloat16, CreateSize) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.h new file mode 100644 index 00000000..32690bc6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/size.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SIZE_H_ +#define NNACL_KERNEL_SIZE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SizeStruct { + KernelBase base_; +} SizeStruct; + +KernelBase *CreateSize(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SIZE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.c new file mode 100644 index 00000000..28ebacc2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.c @@ -0,0 +1,76 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/base/slice_base.h" +#include "nnacl_c/nnacl_common.h" + +int SliceLaunch(void *cdata, int task_id, float l, float r) { + SliceStruct *slice = (SliceStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(slice); + void *in_data = slice->base_.in_[FIRST_INPUT]->data_; + void *out_data = slice->base_.out_[OUTPUT_INDEX]->data_; + DoSlice(in_data, out_data, slice, task_id, slice->base_.thread_nr_, slice->data_type_size_); + return NNACL_OK; +} + +int SliceResize(KernelBase *self) { + SliceStruct *slice = (SliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(slice); + + InitSliceStruct(slice, self->in_[Index0], self->in_[Index1], self->in_[Index2]); + + if (slice->param_length_ < DIMENSION_8D) { + PadSliceParameterTo8D(slice); + } + return NNACL_OK; +} + +int SliceCompute(KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + + SliceStruct *slice = (SliceStruct *)self; + if (slice->size_[Index5] < self->thread_nr_) { + DoSliceNoParallel(in_tensor->data_, out_tensor->data_, slice, slice->data_type_size_); + return NNACL_OK; + } + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, SliceLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +KernelBase *CreateSlice(OpParameter *param, int data_type) { + SliceStruct *slice = (SliceStruct *)malloc(sizeof(SliceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(slice); + slice->data_type_size_ = DataTypeCSize(data_type); + slice->base_.Release = DefaultRelease; + slice->base_.Prepare = DefaultPrepare3In1Out; + slice->base_.Resize = SliceResize; + slice->base_.Compute = SliceCompute; + return (KernelBase *)slice; +} + +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeInt32, CreateSlice) +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeFloat32, CreateSlice) +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeFloat16, CreateSlice) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.h new file mode 100644 index 00000000..d4bd3ce4 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/slice.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_SLICE_H_ +#define NNACL_KERNEL_SLICE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SliceStruct { + KernelBase base_; + int data_type_size_; + int32_t begin_[DIMENSION_8D]; + int32_t size_[DIMENSION_8D]; + int32_t shape_[DIMENSION_8D]; + int32_t end_[DIMENSION_8D]; + int32_t param_length_; +} SliceStruct; + +KernelBase *CreateSlice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SLICE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.c new file mode 100644 index 00000000..967b142d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.c @@ -0,0 +1,157 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/softmax.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/fp32/softmax_fp32.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/softmax_fp16.h" +#endif + +int SoftmaxLastAxisRun(void *cdata, int task_id, float l, float r) { + SoftmaxStruct *softmax = (SoftmaxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + + NNACL_CHECK_ZERO_RETURN_ERR(softmax->base_.thread_nr_); + int unit = UP_DIV(softmax->out_plane_size_, softmax->base_.thread_nr_); + + int *in_shape = softmax->base_.in_[FIRST_INPUT]->shape_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, unit, NNACL_ERR); + int begin = task_id * unit; + int end = MSMIN(begin + unit, softmax->out_plane_size_); + int channel = in_shape[softmax->axis_]; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin, channel, NNACL_ERR); + int offset = begin * channel; + + void *input_ptr = softmax->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = softmax->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + +#ifdef ENABLE_FP16 + if (softmax->data_type_ == kNumberTypeFloat16) { + SoftmaxLastAxisFp16((float16_t *)input_ptr + offset, (float16_t *)output_ptr + offset, end - begin, channel); + return NNACL_OK; + } +#endif + return SoftmaxLastAxis((float *)input_ptr + offset, (float *)output_ptr + offset, end - begin, channel); +} + +int SoftmaxRelease(struct KernelBase *self) { + SoftmaxStruct *softmax = (SoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + if (softmax->sum_data_ != NULL) { + self->env_->Free(self->env_->allocator_, softmax->sum_data_); + } + softmax->sum_data_ = NULL; + return NNACL_OK; +} + +int InitSoftmaxParam(SoftmaxStruct *softmax) { + TensorC *in_tensor = softmax->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + int *in_shape = in_tensor->shape_; + + softmax->n_dim_ = (int)in_tensor->shape_size_; + int origin_axis = ((SoftmaxParameter *)softmax->base_.param_)->axis_; + softmax->axis_ = origin_axis == -1 ? 
origin_axis + softmax->n_dim_ : origin_axis;
+
+  NNACL_CHECK_TRUE_RET(softmax->axis_ >= 0, NNACL_SOFTMAX_AXIS_INVALID);
+  NNACL_CHECK_TRUE_RET(softmax->axis_ < (int)in_tensor->shape_size_, NNACL_SOFTMAX_AXIS_INVALID);
+
+  int out_plane_size = 1;
+  for (int i = 0; i < softmax->axis_; ++i) {
+    out_plane_size *= in_shape[i];
+  }
+  int in_plane_size = 1;
+  for (int i = softmax->axis_ + 1; i < softmax->n_dim_; i++) {
+    in_plane_size *= in_shape[i];
+  }
+
+  ExecEnv *env = softmax->base_.env_;
+  NNACL_CHECK_NULL_RETURN_ERR(env);
+
+  softmax->in_plane_size_ = in_plane_size;
+  softmax->out_plane_size_ = out_plane_size;
+
+  (void)softmax->base_.Release(&softmax->base_);
+  if (softmax->in_plane_size_ > 1) {
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(out_plane_size, in_plane_size, NNACL_ERR);
+    int sum_data_size = out_plane_size * in_plane_size;
+    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(sum_data_size, (int)DataTypeCSize(softmax->data_type_), NNACL_ERR);
+    softmax->sum_data_ = env->Alloc(env->allocator_, sum_data_size * DataTypeCSize(softmax->data_type_));
+    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(softmax->sum_data_);
+  }
+  return NNACL_OK;
+}
+
+int SoftmaxResize(struct KernelBase *self) {
+  SoftmaxStruct *softmax = (SoftmaxStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(softmax);
+  int ret = InitSoftmaxParam(softmax);
+  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
+
+  TensorC *in_tensor = self->in_[FIRST_INPUT];
+  int *in_shape = in_tensor->shape_;
+
+  self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_Softmax), in_shape[softmax->axis_], in_shape[softmax->axis_],
+                                        NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_);
+  return NNACL_OK;
+}
+
+int SoftmaxCompute(struct KernelBase *self) {
+  SoftmaxStruct *softmax = (SoftmaxStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(softmax);
+
+  if (softmax->in_plane_size_ == 1) {
+    return self->env_->ParallelLaunch(self->env_->thread_pool_, SoftmaxLastAxisRun, softmax, self->thread_nr_);
+  }
+
+  void *input_ptr = self->in_[FIRST_INPUT]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(input_ptr);
+  void *output_ptr = self->out_[OUTPUT_INDEX]->data_;
+  NNACL_CHECK_NULL_RETURN_ERR(output_ptr);
+  NNACL_CHECK_NULL_RETURN_ERR(softmax->sum_data_);
+#ifdef ENABLE_FP16
+  if (softmax->data_type_ == kNumberTypeFloat16) {
+    SoftmaxFp16((float16_t *)input_ptr, (float16_t *)output_ptr, (float16_t *)softmax->sum_data_, softmax->axis_,
+                softmax->n_dim_, self->in_[FIRST_INPUT]->shape_);
+    return NNACL_OK;
+  }
+#endif
+  Softmax((float *)input_ptr, (float *)output_ptr, (float *)softmax->sum_data_, softmax->axis_, softmax->n_dim_,
+          self->in_[FIRST_INPUT]->shape_);
+  return NNACL_OK;
+}
+
+KernelBase *CreateSoftmax(OpParameter *param, int data_type) {
+  SoftmaxStruct *softmax = (SoftmaxStruct *)malloc(sizeof(SoftmaxStruct));
+  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(softmax);
+  memset(softmax, 0, sizeof(SoftmaxStruct));
+
+  softmax->sum_data_ = NULL;
+  softmax->data_type_ = data_type;
+  softmax->base_.Release = SoftmaxRelease;
+  softmax->base_.Prepare = DefaultPrepare1In1Out;
+  softmax->base_.Resize = SoftmaxResize;
+  softmax->base_.Compute = SoftmaxCompute;
+  return (KernelBase *)softmax;
+}
+
+REG_KERNEL_CREATOR(PrimType_Softmax, kNumberTypeFloat16, CreateSoftmax)
+REG_KERNEL_CREATOR(PrimType_Softmax, kNumberTypeFloat32, CreateSoftmax)
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.h
new file mode 100644
index 00000000..f37d0e13
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/softmax.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2023 Huawei Technologies 
Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SOFTMAX_H_ +#define NNACL_KERNEL_SOFTMAX_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SoftmaxStruct { + KernelBase base_; + int axis_; + int n_dim_; + int in_plane_size_; + int out_plane_size_; + void *sum_data_; + TypeIdC data_type_; + int unit_; +} SoftmaxStruct; + +int InitSoftmaxParam(SoftmaxStruct *softmax); +int SoftmaxRelease(struct KernelBase *self); +KernelBase *CreateSoftmax(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SOFTMAX_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.c new file mode 100644 index 00000000..be845ae1 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.c @@ -0,0 +1,79 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/splice.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/splice_parameter.h" +#include "nnacl_c/fp32/splice_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/splice_fp16.h" +#endif + +int SpliceCompute(struct KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + NNACL_CHECK_FALSE(input->shape_size_ != output->shape_size_, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(output->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID); + + SpliceParameter *param = (SpliceParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + int src_row = input->shape_[Index1]; + int src_col = input->shape_[Index2]; + int dst_row = output->shape_[Index1]; + int dst_col = output->shape_[Index2]; + + NNACL_CHECK_FALSE(src_col * param->context_dim_ != dst_col, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(param->context_dim_ * dst_row != param->forward_indexes_dim_, NNACL_SPLICE_SHAPE_INVALID); + + for (int i = 0; i < param->forward_indexes_dim_; ++i) { + if (param->forward_indexes_[i] >= src_row) { + return NNACL_SPLICE_SHAPE_INVALID; + } + } + + void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + +#ifdef ENABLE_FP16 + if (input->data_type_ == kNumberTypeFloat16) { + SpliceFp16((float16_t *)input_data, src_row, src_col, param, (float16_t *)output_data, dst_row, dst_col); + return NNACL_OK; + } +#endif + + SpliceFp32((float *)input_data, src_row, src_col, param, (float *)output_data, dst_row, dst_col); + return NNACL_OK; +} + +KernelBase *CreateSplice(OpParameter *param, int data_type) { + SpliceStruct *splice = (SpliceStruct *)malloc(sizeof(SpliceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(splice); + splice->base_.Release = DefaultRelease; + splice->base_.Prepare = DefaultPrepare1In1Out; + splice->base_.Resize = DefaultResize; + splice->base_.Compute = SpliceCompute; + return (KernelBase *)splice; +} + +REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat32, CreateSplice) +REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat16, CreateSplice) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.h new file mode 100644 index 00000000..45b9e39f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/splice.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_SPLICE_H_ +#define NNACL_KERNEL_SPLICE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct SpliceStruct { + KernelBase base_; +} SpliceStruct; + +KernelBase *CreateSplice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SPLICE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.c new file mode 100644 index 00000000..57454922 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.c @@ -0,0 +1,138 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/stack.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/stack_parameter.h" +#include "nnacl_c/nnacl_common.h" +#include "nnacl_c/base/stack_base.h" +#include "nnacl_c/tensor_c_utils.h" + +static inline int GetCopyNum(const int *in_shape, int axis, int n_dim) { + int copy_num = 1; + if (axis > 0) { + for (int j = n_dim - 1; j > axis - 1; j--) { + copy_num *= in_shape[j]; + } + } else { + for (int i = 0; i < n_dim; ++i) { + copy_num *= in_shape[i]; + } + } + return copy_num; +} + +static inline int GetOuterSize(const int *in_shape, int axis) { + int outer_size = 1; + for (int i = 0; i < axis; ++i) { + outer_size *= in_shape[i]; + } + return outer_size; +} + +int StackRelease(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + if (stack->buffers_ != NULL) { + self->env_->Free(self->env_->allocator_, stack->buffers_); + stack->buffers_ = NULL; + } + return NNACL_OK; +} + +int StackPrepare(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + stack->buffers_ = + (void **)self->env_->Alloc(self->env_->allocator_, (self->in_size_ + self->out_size_) * sizeof(void *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack->buffers_); + return NNACL_OK; +} + +int StackResize(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + int origin_axis = ((StackParameter *)self->param_)->axis_; + stack->axis_ = origin_axis < 0 ? 
origin_axis + (int)input->shape_size_ + 1 : origin_axis; + + if (self->in_size_ == 1) { + NNACL_CHECK_FALSE(NNACLGetElementNum(input) <= 0, NNACL_STACK_TENSOR_SHAPE_INVALID); + stack->copy_size_ = (size_t)NNACLGetElementNum(input) * DataTypeCSize(stack->data_type_); + stack->outer_size_ = 1; + } else { + NNACL_CHECK_FALSE((int)input->shape_size_ < stack->axis_, NNACL_STACK_TENSOR_SHAPE_INVALID); + size_t copy_num = (size_t)GetCopyNum(input->shape_, stack->axis_, input->shape_size_); + stack->copy_size_ = copy_num * DataTypeCSize(stack->data_type_); + stack->outer_size_ = GetOuterSize(input->shape_, stack->axis_); + } + + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_Stack), stack->copy_size_, stack->copy_size_, + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + self->thread_nr_ = NNACL_MIN(UP_DIV(stack->outer_size_, NNACL_STACK_STEP), self->thread_nr_); + return NNACL_OK; +} + +int StackRun(void *cdata, int task_id, float l, float r) { + StackStruct *stack = (StackStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(stack); + + NNACL_CHECK_TRUE_RET(stack->base_.thread_nr_ != 0, NNACL_ERR); + int step = UP_DIV(stack->outer_size_, stack->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, step, NNACL_ERR); + int start = task_id * step; + int end = NNACL_MIN(start + step, stack->outer_size_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stack->base_.in_size_ * (size_t)start, stack->copy_size_, NNACL_ERR); + + void *output_data = (void *)(stack->base_.out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(output_data); + uint8_t *output = (uint8_t *)output_data + stack->base_.in_size_ * (size_t)start * stack->copy_size_; + + Stack(stack->buffers_, (void *)output, stack->base_.in_size_, stack->copy_size_, start, end); + return NNACL_OK; +} + +int StackCompute(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + + for (size_t i = 0; i < self->in_size_; ++i) { + stack->buffers_[i] = self->in_[i]->data_; + NNACL_CHECK_NULL_RETURN_ERR(stack->buffers_[i]); + } + stack->buffers_[self->in_size_] = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(stack->buffers_[self->in_size_]); + return self->env_->ParallelLaunch(self->env_->thread_pool_, StackRun, self, self->thread_nr_); +} + +KernelBase *CreateStack(OpParameter *param, int data_type) { + StackStruct *stack = (StackStruct *)malloc(sizeof(StackStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(stack); + stack->buffers_ = NULL; + stack->data_type_ = data_type; + stack->base_.Release = StackRelease; + stack->base_.Prepare = StackPrepare; + stack->base_.Resize = StackResize; + stack->base_.Compute = StackCompute; + return (KernelBase *)stack; +} + +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeFloat32, CreateStack) +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeInt32, CreateStack) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.h new file mode 100644 index 00000000..e02d1cef --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/stack.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_STACK_H_ +#define NNACL_KERNEL_STACK_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +#define NNACL_STACK_STEP 64 + +typedef struct StackStruct { + KernelBase base_; + TypeIdC data_type_; + int axis_; + int outer_size_; + size_t copy_size_; + void **buffers_; +} StackStruct; + +KernelBase *CreateStack(OpParameter *param, int data_type); +int StackRun(void *cdata, int task_id, float l, float r); +int StackRelease(KernelBase *self); +int StackPrepare(KernelBase *self); +int StackResize(KernelBase *self); + +#endif // NNACL_KERNEL_STACK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.c new file mode 100644 index 00000000..3db06715 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.c @@ -0,0 +1,278 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/strided_slice.h"
+#include "nnacl_c/strided_slice_parameter.h"
+#include "nnacl_c/nnacl_common.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/fp32/strided_slice_fp32.h"
+#include "nnacl_c/kernel/reshape.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+#include "nnacl_c/tensor_c_utils.h"
+
+#define MinStridedSlicePerThread 16384
+
+int StridedSliceFastParallelRun(void *cdata, int task_id, float l, float r) {
+  StridedSliceStruct *strided_slice = (StridedSliceStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(strided_slice);
+
+  uint8_t *input_data = strided_slice->base_.in_[FIRST_INPUT]->data_;
+  uint8_t *output_data = strided_slice->base_.out_[OUTPUT_INDEX]->data_;
+  int *in_shape = strided_slice->base_.in_[FIRST_INPUT]->shape_;
+  int *out_shape = strided_slice->base_.out_[OUTPUT_INDEX]->shape_;
+  int begin_index = strided_slice->begins_[strided_slice->split_axis_];
+  int caled_num = task_id * strided_slice->cal_num_per_thread_;
+  int64_t inner_size = (int64_t)strided_slice->inner_size_;
+
+  if (strided_slice->parallel_on_outer_) {
+    uint8_t *cur_in_ptr = input_data + (caled_num * in_shape[strided_slice->split_axis_] + begin_index) * inner_size;
+    uint8_t *cur_out_ptr = output_data + caled_num * out_shape[strided_slice->split_axis_] * inner_size;
+    int cur_outer = (int)strided_slice->outer_ - caled_num;
+    if (cur_outer <= 0) {
+      return NNACL_OK;
+    }
+    if (cur_outer > strided_slice->cal_num_per_thread_) {
+      cur_outer = strided_slice->cal_num_per_thread_;
+    }
+    FastStride(cur_in_ptr, cur_out_ptr, out_shape[strided_slice->split_axis_],
+               strided_slice->strides_[strided_slice->split_axis_], cur_outer, strided_slice->inner_size_,
+               (size_t)in_shape[strided_slice->split_axis_] * strided_slice->inner_size_);
+    return NNACL_OK;
+  }
+
+  if (strided_slice->parallel_on_split_axis_) {
+    uint8_t *cur_in_ptr =
+      input_data + (caled_num * strided_slice->strides_[strided_slice->split_axis_] + begin_index) * inner_size;
+    uint8_t *cur_out_ptr = output_data + caled_num * inner_size;
+    int cal_axis_num = out_shape[strided_slice->split_axis_] - caled_num;
+    if (cal_axis_num <= 0) {
+      return NNACL_OK;
+    }
+    if (cal_axis_num > strided_slice->cal_num_per_thread_) {
+      cal_axis_num = strided_slice->cal_num_per_thread_;
+    }
+    FastStride(cur_in_ptr, cur_out_ptr, (uint32_t)cal_axis_num, strided_slice->strides_[strided_slice->split_axis_], 1,
+               strided_slice->inner_size_, 0);
+    return NNACL_OK;
+  }
+
+  return NNACL_STRIDED_SLICE_INVALID_PARALLEL_MOD;
+}
+
+int StridedSliceFastRun(StridedSliceStruct *strided_slice) {
+  // Recompute the inner size in bytes here, because the tensor data type may
+  // be changed from float32 to float16 during fp16 sub-graph partitioning.
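+  // Illustrative arithmetic (values assumed, not taken from this patch): with
+  // inner_ = 40 elements, inner_size_ below is 40 * 4 = 160 bytes for a
+  // float32 tensor, but 40 * 2 = 80 bytes after a cast to float16.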
+  size_t data_type_size = DataTypeCSize(strided_slice->base_.in_[FIRST_INPUT]->data_type_);
+  NNACL_CHECK_FALSE(data_type_size == 0, NNACL_STRIDED_SLICE_UNSUPPORTED_DATA_TYPE);
+  strided_slice->inner_size_ = strided_slice->inner_ * data_type_size;
+
+  NNACL_CHECK_NULL_RETURN_ERR(strided_slice->base_.in_[FIRST_INPUT]->data_);
+  NNACL_CHECK_NULL_RETURN_ERR(strided_slice->base_.out_[OUTPUT_INDEX]->data_);
+  return strided_slice->base_.env_->ParallelLaunch(strided_slice->base_.env_->thread_pool_,
+                                                   StridedSliceFastParallelRun, strided_slice,
+                                                   strided_slice->base_.thread_nr_);
+}
+
+bool StridedSliceMatchInOutShapeEqualPattern(StridedSliceStruct *strided_slice) {
+  for (int i = 0; i < MAX_SHAPE_SIZE; i++) {
+    if (strided_slice->strides_[i] < 0) {
+      return false;
+    }
+  }
+
+  TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX];
+
+  if (in_tensor->data_type_ != out_tensor->data_type_) {
+    return false;
+  }
+
+  if (in_tensor->shape_size_ != out_tensor->shape_size_) {
+    return false;
+  }
+
+  if (in_tensor->shape_size_ < ONE_TENSOR) {
+    return false;
+  }
+
+  for (size_t i = 0; i < in_tensor->shape_size_; ++i) {
+    if (in_tensor->shape_[i] != out_tensor->shape_[i]) {
+      return false;
+    }
+    if (in_tensor->shape_[i] == -1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+int StridedSliceSoftCopyInputToOutput(StridedSliceStruct *strided_slice) {
+  TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
+  NNACL_CHECK_NULL_RETURN_ERR(in_tensor->data_);
+  TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX];
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
+  NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_);
+
+  int total_num = NNACLGetElementNum(in_tensor);
+  NNACL_CHECK_FALSE(total_num == 0, NNACL_STRIDED_SLICE_INVALID_DATA_SIZE);
+
+  strided_slice->base_.thread_nr_ =
+    NNACL_MIN(strided_slice->base_.thread_nr_, UP_DIV(total_num, MinStridedSlicePerThread));
+  if (strided_slice->base_.thread_nr_ < 1) {
+    strided_slice->base_.thread_nr_ = 1;
+  }
+
+  int block_num = UP_DIV(total_num, strided_slice->base_.thread_nr_);
+  strided_slice->base_.thread_nr_ = UP_DIV(total_num, block_num);
+
+  if (in_tensor->data_ != out_tensor->data_) {
+    if (strided_slice->base_.thread_nr_ == 1) {
+      (void)memcpy(out_tensor->data_, in_tensor->data_, total_num * (int)DataTypeCSize(in_tensor->data_type_));
+      return NNACL_OK;
+    }
+    ReshapeStruct reshape;
+    reshape.base_.in_ = strided_slice->base_.in_;
+    reshape.base_.out_ = strided_slice->base_.out_;
+    reshape.block_num_ = block_num;
+    reshape.total_num_ = total_num;
+    reshape.base_.thread_nr_ = strided_slice->base_.thread_nr_;
+    return strided_slice->base_.env_->ParallelLaunch(strided_slice->base_.env_->thread_pool_, ParallelReshape,
+                                                     &reshape, strided_slice->base_.thread_nr_);
+  }
+  return NNACL_OK;
+}
+
+bool StridedSliceMatchFastPattern(StridedSliceStruct *strided_slice) {
+  // This function checks whether input and output differ in exactly one
+  // dimension. If so, we can take a faster slicing path.
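+  // The single differing axis becomes split_axis_, and FastStride then copies
+  // the contiguous inner blocks around that axis, for instance: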
+  // Example 1:
+  // input shape info: [1, 80, 46, 40]
+  // output shape info: [1, 80, 20, 40]
+  // Example 2:
+  // input shape info: [1, 46, 40]
+  // output shape info: [1, 20, 40]
+  TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX];
+  if (in_tensor->shape_size_ != out_tensor->shape_size_) {
+    return false;
+  }
+
+  int axis_list[MAX_SHAPE_SIZE];
+  int axis_list_size = 0;
+  for (size_t i = 0; i < in_tensor->shape_size_; i++) {
+    if (in_tensor->shape_[i] != out_tensor->shape_[i]) {
+      axis_list[axis_list_size++] = (int)i;
+    }
+  }
+  if (axis_list_size == 1) {
+    strided_slice->split_axis_ = axis_list[Index0];
+    return true;
+  }
+  return false;
+}
+
+void StridedSliceInitFastRunParam(StridedSliceStruct *strided_slice) {
+  TensorC *input_tensor = strided_slice->base_.in_[FIRST_INPUT];
+  int *in_shape = input_tensor->shape_;
+  int *out_shape = strided_slice->base_.out_[OUTPUT_INDEX]->shape_;
+
+  // reset and compute the inner and outer sizes
+  strided_slice->outer_ = 1;
+  strided_slice->inner_ = 1;
+  for (int i = 0; i < strided_slice->split_axis_; ++i) {
+    strided_slice->outer_ *= (size_t)in_shape[i];
+  }
+  for (size_t i = (size_t)strided_slice->split_axis_ + 1; i < input_tensor->shape_size_; i++) {
+    strided_slice->inner_ *= (size_t)in_shape[i];
+  }
+
+  if (strided_slice->outer_ == 1) {
+    strided_slice->parallel_on_split_axis_ = true;
+    strided_slice->parallel_on_outer_ = false;
+  } else {
+    strided_slice->parallel_on_split_axis_ = false;
+    strided_slice->parallel_on_outer_ = true;
+  }
+
+  strided_slice->base_.thread_nr_ = strided_slice->base_.UpdateThread(
+    TC_TYPE(PrimType_StridedSlice, strided_slice->parallel_on_outer_), 1, 1,
+    NNACLGetElementNum(strided_slice->base_.out_[OUTPUT_INDEX]), strided_slice->base_.thread_nr_);
+
+  strided_slice->cal_num_per_thread_ =
+    strided_slice->parallel_on_split_axis_
+      ? 
UP_DIV(out_shape[strided_slice->split_axis_], strided_slice->base_.thread_nr_) + : UP_DIV((int)strided_slice->outer_, strided_slice->base_.thread_nr_); +} + +int StridedSliceResize(KernelBase *self) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->shape_size_ > MAX_SHAPE_SIZE, NNACL_STRIDED_SLICE_INVALID_SHAPE_SIZE); + + StridedSliceParameter *param = (StridedSliceParameter *)self->param_; + memcpy(strided_slice->begins_, param->begins_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->ends_, param->ends_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->in_shape_, param->in_shape_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->strides_, param->strides_, MAX_SHAPE_SIZE * sizeof(int)); + strided_slice->in_shape_size_ = param->in_shape_length_; + + strided_slice->soft_copy_mode_ = StridedSliceMatchInOutShapeEqualPattern(strided_slice); + strided_slice->fast_run_ = StridedSliceMatchFastPattern(strided_slice); + if (strided_slice->fast_run_) { + StridedSliceInitFastRunParam(strided_slice); + } + + if (strided_slice->soft_copy_mode_ == false && strided_slice->fast_run_ == false) { + return PadStridedSliceParameterTo8D(strided_slice); + } + + return NNACL_OK; +} + +int StridedSliceCompute(KernelBase *self) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + + if (strided_slice->soft_copy_mode_) { + return StridedSliceSoftCopyInputToOutput(strided_slice); + } + if (strided_slice->fast_run_) { + return StridedSliceFastRun(strided_slice); + } + + return DoStridedSliceIn8D(self->in_[FIRST_INPUT]->data_, self->out_[OUTPUT_INDEX]->data_, strided_slice); +} + +KernelBase *CreateStridedSlice(OpParameter *param, int data_type) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)malloc(sizeof(StridedSliceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(strided_slice); + strided_slice->data_type_ = data_type; + strided_slice->base_.Release = DefaultRelease; + strided_slice->base_.Prepare = DefaultPrepare1In1Out; + strided_slice->base_.Resize = StridedSliceResize; + strided_slice->base_.Compute = StridedSliceCompute; + return (KernelBase *)strided_slice; +} + +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat32, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.h new file mode 100644 index 00000000..5fe3fc80 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/strided_slice.h @@ -0,0 +1,47 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_STRIDED_SLICE_H_ +#define NNACL_KERNEL_STRIDED_SLICE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct StridedSliceStruct { + KernelBase base_; + TypeIdC data_type_; + bool fast_run_; + bool soft_copy_mode_; + bool parallel_on_outer_; + bool parallel_on_split_axis_; + + int split_axis_; + int in_shape_size_; + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int in_shape_[MAX_SHAPE_SIZE]; + + size_t inner_; + size_t outer_; + size_t inner_size_; + int cal_num_per_thread_; +} StridedSliceStruct; + +KernelBase *CreateStridedSlice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_STRIDED_SLICE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.c new file mode 100644 index 00000000..18e33950 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.c @@ -0,0 +1,182 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "nnacl_c/kernel/tile.h"
+#include "nnacl_c/tile_parameter.h"
+#include "nnacl_c/tensor_c_utils.h"
+#include "nnacl_c/nnacl_common.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/tile_base.h"
+#include "nnacl_c/kernel/default_kernel_base.h"
+
+#define kDoubleInputsSize 2
+
+int TileDoubleInputScenes(TileStruct *tile) {
+  TensorC *t = tile->base_.in_[SECOND_INPUT];
+  if (t->data_ == NULL) {
+    tile->resize_done_ = false;
+    return NNACL_OK;
+  }
+
+  NNACL_CHECK_FALSE(NNACLGetElementNum(t) > (int)tile->base_.in_[FIRST_INPUT]->shape_size_,
+                    NNACL_TILE_SECOND_INPUT_NUM_INVALID);
+  NNACL_CHECK_FALSE(t->data_type_ != kNumberTypeInt && t->data_type_ != kNumberTypeInt32,
+                    NNACL_TILE_SECOND_INPUT_DATA_TYPE_INVALID);
+
+  int *input1_addr = (int *)(t->data_);
+  for (int i = 0; i < NNACLGetElementNum(t); ++i) {
+    NNACL_CHECK_FALSE(input1_addr[i] <= 0, NNACL_TILE_SECOND_INPUT_VALUE_INVALID);
+    tile->dims_[i] = i;
+    tile->multiples_[i] = input1_addr[i];
+  }
+  return NNACL_OK;
+}
+
+int SimpleTileImpl(TileStruct *tile, int task_id) {
+  NNACL_CHECK_ZERO_RETURN_ERR(tile->base_.thread_nr_);
+  size_t unit = UP_DIV(tile->fast_outer_size_, (size_t)tile->base_.thread_nr_);
+  if (unit == 0 && task_id > 0) {
+    return NNACL_OK;
+  }
+  NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(unit, (size_t)task_id), NNACL_ERR);
+  size_t begin = unit * (size_t)(task_id);
+  size_t end = MSMIN(begin + unit, tile->fast_outer_size_);
+  TileSimple(tile->input_addr_, tile->output_addr_, begin, end, tile);
+  return NNACL_OK;
+}
+
+int SimpleTile(void *cdata, int task_id, float l, float r) {
+  TileStruct *tile = (TileStruct *)cdata;
+  NNACL_CHECK_NULL_RETURN_ERR(tile);
+  return SimpleTileImpl(tile, task_id);
+}
+
+int TileFillOneDimTileParam(TileStruct *tile) {
+  // check whether exactly one dimension is tiled
+  int large_one_multiple_count = 0;
+  int multiple = 0;
+  int mul_index = 0;
+
+  for (int i = 0; i < tile->in_dim_; ++i) {
+    if (tile->multiples_[i] > 1) {
+      large_one_multiple_count++;
+      multiple = tile->multiples_[i];
+      mul_index = i;
+    }
+  }
+  tile->one_dim_tile_ = large_one_multiple_count == 1;
+  if (tile->one_dim_tile_) {
+    tile->fast_multiple_ = (size_t)multiple;
+    NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->in_shape_[mul_index], tile->in_strides_[mul_index]), NNACL_ERR);
+    tile->fast_stride_ = (size_t)(tile->in_shape_[mul_index] * tile->in_strides_[mul_index]);
+    NNACL_CHECK_FALSE(tile->fast_stride_ < 1, NNACL_TILE_INPUT_SHAPE_INVALID);
+    tile->fast_outer_size_ = (size_t)NNACLGetElementNum(tile->base_.in_[FIRST_INPUT]) / tile->fast_stride_;
+  }
+  tile->resize_done_ = true;
+  return NNACL_OK;
+}
+
+int TileResize(struct KernelBase *self) {
+  TileStruct *tile = (TileStruct *)self;
+  NNACL_CHECK_NULL_RETURN_ERR(tile);
+  TileParameter *param = (TileParameter *)(self->param_);
+  NNACL_CHECK_NULL_RETURN_ERR(param);
+
+  tile->dims_size_ = param->dims_size_;
+  for (int i = 0; i < MAX_SHAPE_SIZE; i++) {
+    tile->dims_[i] = param->dims_[i];
+    tile->multiples_[i] = param->multiples_[i];
+  }
+
+  if (self->in_size_ == kDoubleInputsSize) {
+    int ret = TileDoubleInputScenes(tile);
+    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
+  }
+
+  TensorC *input = self->in_[0];
+  TensorC *output = self->out_[0];
+  NNACL_CHECK_NULL_RETURN_ERR(input);
+  NNACL_CHECK_NULL_RETURN_ERR(output);
+
+  tile->in_dim_ = (int)input->shape_size_;
+  NNACL_CHECK_TRUE_RET(tile->in_dim_ > 0 && tile->in_dim_ <= MAX_SHAPE_SIZE, NNACL_TILE_INPUT_SHAPE_INVALID);
+  NNACL_CHECK_FALSE((int)output->shape_size_ < tile->in_dim_, NNACL_TILE_INPUT_SHAPE_INVALID);
+
+  for (int i = 0; i < 
tile->in_dim_; ++i) { + tile->in_shape_[i] = input->shape_[i]; + tile->out_shape_[i] = output->shape_[i]; + } + + ComputeStrides(tile->in_shape_, tile->in_strides_, tile->in_dim_); + ComputeStrides(tile->out_shape_, tile->out_strides_, tile->in_dim_); + + for (size_t i = 0; i < tile->dims_size_; i++) { + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->multiples_[i], tile->in_shape_[i]), NNACL_ERRCODE_MUL_OVERFLOW); + int ele_num = tile->multiples_[i] * tile->in_shape_[i] - 1; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->out_strides_[i], ele_num), NNACL_ERRCODE_MUL_OVERFLOW); + } + + int ret = TileFillOneDimTileParam(tile); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (tile->one_dim_tile_) { + self->thread_nr_ = + self->UpdateThread(TC_TYPE(PrimType_TileFusion, 0), 0, 0, tile->fast_outer_size_, self->thread_nr_); + } + return NNACL_OK; +} + +int TileCompute(struct KernelBase *self) { + TileStruct *tile = (TileStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(tile); + tile->input_addr_ = (uint8_t *)(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(tile->input_addr_); + tile->output_addr_ = (uint8_t *)(self->out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(tile->output_addr_); + + if (!tile->resize_done_) { + int ret = TileResize(self); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + NNACL_CHECK_FALSE(tile->resize_done_ == false, NNACL_TILE_RESIZE_IN_RUNTIME_FAILED); + } + + tile->data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_TRUE_RET(tile->data_size_ > 0, NNACL_UNSUPPORTED_DATA_TYPE); + + if (tile->one_dim_tile_) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, SimpleTile, self, self->thread_nr_); + } + + Tile(tile->input_addr_, tile->output_addr_, tile); + return NNACL_OK; +} + +KernelBase *CreateTile(OpParameter *param, int data_type) { + TileStruct *tile = (TileStruct *)malloc(sizeof(TileStruct)); + NNACL_CHECK_NULL_RETURN_NULL(tile); + tile->resize_done_ = false; + tile->base_.Release = DefaultRelease; + tile->base_.Prepare = DefaultPrepare1In1Out; + tile->base_.Resize = TileResize; + tile->base_.Compute = TileCompute; + return (KernelBase *)tile; +} + +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeInt32, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeFloat32, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeFloat16, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeBool, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeUInt8, CreateTile) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.h new file mode 100644 index 00000000..e7100004 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tile.h @@ -0,0 +1,48 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_TILE_H_ +#define NNACL_KERNEL_TILE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct TileStruct { + KernelBase base_; + bool one_dim_tile_; + bool resize_done_; + int dims_[MAX_SHAPE_SIZE]; + size_t dims_size_; + uint8_t *input_addr_; + uint8_t *output_addr_; + + int multiples_[MAX_SHAPE_SIZE]; + int in_shape_[MAX_SHAPE_SIZE]; + int out_shape_[MAX_SHAPE_SIZE]; + int in_strides_[MAX_SHAPE_SIZE]; + int out_strides_[MAX_SHAPE_SIZE]; + + int in_dim_; + size_t data_size_; + size_t fast_outer_size_; + size_t fast_stride_; + size_t fast_multiple_; +} TileStruct; + +KernelBase *CreateTile(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TILE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.c new file mode 100644 index 00000000..470ea301 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.c @@ -0,0 +1,358 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/transpose.h" +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/pack_fp16.h" +#include "nnacl_c/fp16/transpose_fp16.h" +#endif + +/* opt perm: { 0, 2, 1 } */ +#define OPT_PERM_0 0 +#define OPT_PERM_1 2 +#define OPT_PERM_2 1 + +int TransposeComputeinMultiThread(TransposeStruct *transpose, int task_id) { + void *in = transpose->base_.in_[FIRST_INPUT]->data_; + void *out = transpose->base_.out_[OUTPUT_INDEX]->data_; + + if (transpose->opt_run_) { + transpose->nhwc2nchw_(in, out, transpose->opt_perm_[FIRST_INPUT], transpose->opt_perm_[SECOND_INPUT], + transpose->opt_perm_[THIRD_INPUT], task_id, transpose->base_.thread_nr_); + } else { + transpose->optimize_(in, out, transpose->out_shape_, transpose->perm_, transpose->strides_, transpose->out_strides_, + transpose->num_axes_, task_id, transpose->base_.thread_nr_); + } + return NNACL_OK; +} + +int TransposeComputeinSingleThread(TransposeStruct *transpose) { + if (transpose->opt_run_ || transpose->num_axes_ > DIMENSION_6D) { + return TransposeComputeinMultiThread(transpose, 0); + } + + void *in = transpose->base_.in_[FIRST_INPUT]->data_; + void *out = transpose->base_.out_[OUTPUT_INDEX]->data_; + return transpose->compute_(in, out, transpose->out_shape_, transpose->perm_, transpose->strides_, + transpose->out_strides_, transpose->data_num_, transpose->num_axes_); +} + +int ResetTransposeStatus(TransposeStruct *transpose) { + transpose->num_axes_ = 0; + if (transpose->base_.in_size_ == C2NUM) { + transpose->num_axes_ = NNACLGetElementNum(transpose->base_.in_[SECOND_INPUT]); + transpose->perm_size_ = transpose->base_.in_[SECOND_INPUT]->shape_[0]; + } + + TensorC *in_tensor = 
transpose->base_.in_[FIRST_INPUT]; + if (in_tensor->shape_size_ > MAX_TRANSPOSE_DIM_SIZE) { + return NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE; + } + + int trans_nd[MAX_TRANSPOSE_DIM_SIZE] = {0, 2, 1}; + int *perm_data; + if ((int)in_tensor->shape_size_ != transpose->num_axes_) { + perm_data = trans_nd; + if (in_tensor->shape_size_ == Num3 && transpose->num_axes_ == Num4) { + transpose->num_axes_ = Num3; + } + if (transpose->num_axes_ == 0) { + for (size_t i = 0; i < in_tensor->shape_size_; ++i) { + trans_nd[i] = (int)in_tensor->shape_size_ - 1 - (int)i; + } + transpose->num_axes_ = (int)in_tensor->shape_size_; + } + } else { + NNACL_CHECK_TRUE_RET(transpose->base_.in_size_ == TWO_TENSOR, NNACL_TRANSPOSE_INPUT_TENSOR_NUM_INVALID); + TensorC *perm_tensor = transpose->base_.in_[SECOND_INPUT]; + if (perm_tensor->data_type_ != kNumberTypeInt32) { + return NNACL_TRANSPOSE_PERM_TENSOR_INVALID; + } + perm_data = (int *)(perm_tensor->data_); + NNACL_CHECK_NULL_RETURN_ERR(perm_data); + int ele_num = NNACLGetElementNum(perm_tensor); + for (int i = 0; i < ele_num; i++) { + for (int j = 0; j < ele_num; j++) { + if (i == perm_data[j]) { + break; + } + if (j == ele_num - 1) { + return NNACL_TRANSPOSE_PERM_TENSOR_VALUE_INVALID; + } + } + } + } + + NNACL_CHECK_TRUE_RET(transpose->num_axes_ <= MAX_TRANSPOSE_DIM_SIZE, NNACL_TRANSPOSE_PERM_DIMS_INVALID); + for (int i = 0; i < transpose->num_axes_; ++i) { + transpose->perm_[i] = perm_data[i]; + } + return NNACL_OK; +} + +void TransposeFreeSegments(int **segments, int segments_size) { + for (int i = 0; i < segments_size; i++) { + if (segments[i] != NULL) { + free(segments[i]); + segments[i] = NULL; + } + } +} + +int TransposeOptimizeShape(TransposeStruct *transpose) { + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + int *in_shape = in_tensor->shape_; + + // first step, delete dimension where value is 1. 
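+  // Worked example (shapes assumed for illustration): with in_shape
+  // [1, 8, 1, 16] and perm [0, 3, 1, 2], dropping the size-1 axes leaves
+  // shape [8, 16] and perm [1, 0], i.e. a plain 2-D matrix transpose.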
+  int in_shape_temp[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  int in_shape_temp_size = 0;
+  int perm_diff[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  for (size_t i = 0; i < in_tensor->shape_size_; ++i) {
+    if (in_shape[i] != 1) {
+      in_shape_temp[in_shape_temp_size++] = in_shape[i];
+      continue;
+    }
+    for (size_t j = 0; j < in_tensor->shape_size_; ++j) {
+      if (transpose->perm_[j] < (int)(i)) {
+        continue;
+      }
+      if (transpose->perm_[j] == (int)(i)) {
+        perm_diff[j] = (int)(i) + 1;
+      } else {
+        perm_diff[j] += 1;
+      }
+    }
+  }
+
+  int perm_temp[MAX_TRANSPOSE_DIM_SIZE] = {0};
+  int perm_temp_size = 0;
+  for (size_t i = 0; i < in_tensor->shape_size_; ++i) {
+    int diff = transpose->perm_[i] - perm_diff[i];
+    if (diff < 0) {
+      continue;
+    }
+    perm_temp[perm_temp_size++] = diff;
+  }
+
+  NNACL_CHECK_TRUE_RET(in_shape_temp_size == perm_temp_size, NNACL_TRANSPOSE_PERM_DELETE_DIMENSION_FAILED);
+
+  // second step, fuse contiguous dimensions.
+  int axis_num = in_shape_temp_size;
+  int *segments[MAX_TRANSPOSE_DIM_SIZE];
+  int segment_sizes[MAX_TRANSPOSE_DIM_SIZE];
+  int segments_size = 0;
+  for (int i = 0; i < axis_num;) {
+    int segment[MAX_TRANSPOSE_DIM_SIZE];
+    int segment_size = 0;
+    segment[segment_size++] = perm_temp[i];
+    ++i;
+    for (; i < axis_num; ++i) {
+      if (perm_temp[i] - 1 != perm_temp[i - 1]) {
+        break;
+      }
+      segment[segment_size++] = perm_temp[i];
+    }
+
+    segments[segments_size] = malloc(segment_size * sizeof(int));
+    if (segments[segments_size] == NULL) {
+      TransposeFreeSegments(segments, segments_size);
+      return NNACL_NULL_PTR;
+    }
+    memcpy(segments[segments_size], segment, segment_size * sizeof(int));
+    segment_sizes[segments_size] = segment_size;
+    segments_size++;
+  }
+
+  transpose->in_shape_size_ = segments_size;
+  transpose->perm_size_ = segments_size;
+  for (int i = 0; i < segments_size; i++) {
+    transpose->in_shape_[i] = 1;
+    transpose->perm_[i] = 0;
+  }
+  for (int i = 0; i < segments_size; ++i) {
+    for (int j = 0; j < segments_size; ++j) {
+      transpose->perm_[i] += (segments[j][FIRST_INPUT] < segments[i][FIRST_INPUT] ? 
1 : 0); + } + for (int k = 0; k < segment_sizes[i]; ++k) { + transpose->in_shape_[transpose->perm_[i]] *= in_shape_temp[segments[i][k]]; + } + } + TransposeFreeSegments(segments, segments_size); + return NNACL_OK; +} + +void SetTransposeOptInfo(TransposeStruct *transpose) { + // now perm is [1, 0] or [0, 2, 1] + if (transpose->perm_size_ == C2NUM) { + transpose->opt_perm_[FIRST_INPUT] = 1; + transpose->opt_perm_[SECOND_INPUT] = transpose->in_shape_[FIRST_INPUT]; + transpose->opt_perm_[THIRD_INPUT] = transpose->in_shape_[transpose->in_shape_size_ - 1]; + } else { + transpose->opt_perm_[FIRST_INPUT] = transpose->in_shape_[FIRST_INPUT]; + transpose->opt_perm_[SECOND_INPUT] = transpose->in_shape_[SECOND_INPUT]; + transpose->opt_perm_[THIRD_INPUT] = transpose->in_shape_[transpose->in_shape_size_ - 1]; + } +} + +bool TransposeOpt(TransposeStruct *transpose) { + if (transpose->perm_size_ == DIMENSION_2D) { + return true; + } + if (transpose->perm_size_ == DIMENSION_3D && transpose->perm_[FIRST_INPUT] == OPT_PERM_0 && + transpose->perm_[SECOND_INPUT] == OPT_PERM_1 && transpose->perm_[THIRD_INPUT] == OPT_PERM_2) { + return true; + } + return false; +} + +int TransposeComputeOfflineInfo(TransposeStruct *transpose) { + transpose->num_axes_ = transpose->in_shape_size_; + NNACL_CHECK_TRUE_RET(transpose->num_axes_ >= DIMENSION_3D, NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE); + + for (int i = 0; i < transpose->num_axes_; ++i) { + transpose->out_shape_[i] = transpose->in_shape_[transpose->perm_[i]]; + } + transpose->strides_[transpose->num_axes_ - 1] = 1; + transpose->out_strides_[transpose->num_axes_ - 1] = 1; + transpose->data_num_ = NNACLGetElementNum(transpose->base_.in_[FIRST_INPUT]); + for (int i = transpose->num_axes_ - 2; i >= 0; i--) { + transpose->strides_[i] = transpose->in_shape_[i + 1] * transpose->strides_[i + 1]; + transpose->out_strides_[i] = transpose->out_shape_[i + 1] * transpose->out_strides_[i + 1]; + } + return NNACL_OK; +} + +int TransposeCopyInputToOutput(TransposeStruct *transpose) { + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + NNACL_CHECK_NULL_RETURN_ERR(in_tensor->data_); + TensorC *out_tensor = transpose->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + + NNACL_CHECK_FALSE(NNACLGetSize(in_tensor) == 0, NNACL_TRANSPOSE_INPUT_TENSOR_VALUD_INVALID); + if (in_tensor->data_ != out_tensor->data_) { + (void)memcpy(out_tensor->data_, in_tensor->data_, NNACLGetSize(in_tensor)); + } + return NNACL_OK; +} + +int TransposeImpl(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + TransposeStruct *transpose = (TransposeStruct *)cdata; + return TransposeComputeinMultiThread(transpose, task_id); +} + +int TransposeCompute(struct KernelBase *self) { + TransposeStruct *transpose = (TransposeStruct *)self; + if (!transpose->is_valid_) { + return TransposeCopyInputToOutput(transpose); + } + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + if (self->thread_nr_ == 1) { + return TransposeComputeinSingleThread(transpose); + } + return self->env_->ParallelLaunch(self->env_->thread_pool_, TransposeImpl, self, self->thread_nr_); +} + +int TransposeResize(struct KernelBase *self) { + TransposeStruct *transpose = (TransposeStruct *)self; + int ret = 
ResetTransposeStatus(transpose); + if (ret != NNACL_OK) { + return ret; + } + transpose->is_valid_ = (int)transpose->base_.in_[FIRST_INPUT]->shape_size_ == transpose->num_axes_ && + (int)transpose->base_.in_[FIRST_INPUT]->shape_size_ == transpose->perm_size_; + if (!transpose->is_valid_) { + return NNACL_OK; + } + + ret = TransposeOptimizeShape(transpose); + if (ret != NNACL_OK) { + return ret; + } + + transpose->is_valid_ = transpose->perm_size_ > DIMENSION_1D; + if (!transpose->is_valid_) { + return NNACL_OK; + } + + transpose->opt_run_ = TransposeOpt(transpose); + if (transpose->opt_run_) { + SetTransposeOptInfo(transpose); + return NNACL_OK; + } + + ret = TransposeComputeOfflineInfo(transpose); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = (!transpose->opt_run_ && transpose->num_axes_ <= DIMENSION_6D) ? 1 : self->thread_nr_; + return NNACL_OK; +} + +int TransposePrepare(struct KernelBase *self) { + int ret = DefaultPrepare1In1Out(self); + if (ret != NNACL_OK) { + return ret; + } + TransposeStruct *transpose = (TransposeStruct *)self; + TransposeParameter *param = (TransposeParameter *)transpose->base_.param_; + if (param->perm_size_ > INT32_MAX) { + return NNACL_TRANSPOSE_PERM_DIMS_INVALID; + } + transpose->perm_size_ = (int)param->perm_size_; + for (int i = 0; i < transpose->perm_size_; i++) { + transpose->perm_[i] = param->perm_[i]; + } + return NNACL_OK; +} + +KernelBase *CreateTranspose(OpParameter *param, int data_type) { + TransposeStruct *transpose = (TransposeStruct *)malloc(sizeof(TransposeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(transpose); + transpose->nhwc2nchw_ = PackNHWCToNCHWFp32; + transpose->optimize_ = TransposeDimsFp32; + transpose->compute_ = DoTransposeFp32; + transpose->base_.Release = DefaultRelease; + transpose->base_.Prepare = TransposePrepare; + transpose->base_.Resize = TransposeResize; + transpose->base_.Compute = TransposeCompute; + if (data_type == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + transpose->nhwc2nchw_ = PackNHWCToNCHWFp16; + transpose->optimize_ = TransposeDimsFp16; + transpose->compute_ = DoTransposeFp16; +#else + free(transpose); + return NULL; +#endif + } + return (KernelBase *)transpose; +} + +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeFloat32, CreateTranspose) +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeFloat16, CreateTranspose) +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeInt32, CreateTranspose) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.h new file mode 100644 index 00000000..ca8e7a0d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/transpose.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_TRANSPOSE_H_ +#define NNACL_KERNEL_TRANSPOSE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/kernel.h" +#include "nnacl_c/transpose_parameter.h" + +typedef struct TransposeStruct { + KernelBase base_; + bool is_valid_; + int num_axes_; + int data_num_; + int perm_[MAX_TRANSPOSE_DIM_SIZE]; + int perm_size_; + int in_shape_[MAX_TRANSPOSE_DIM_SIZE]; /* after shape optimize */ + int in_shape_size_; + int out_shape_[MAX_TRANSPOSE_DIM_SIZE]; + int strides_[MAX_TRANSPOSE_DIM_SIZE]; + int out_strides_[MAX_TRANSPOSE_DIM_SIZE]; + + int opt_perm_[PERM_NUM_THREE]; // only valid when opt_run_ is true + bool opt_run_; // only true when perm is [1, 0] or [0, 2, 1] + + int (*compute_)(const void *src, void *dst, const int *out_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes); + void (*nhwc2nchw_)(const void *src, void *dst, int b, int hw, int c, int task_id, int thread); + void (*optimize_)(const void *src, void *dst, const int *out_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread); +} TransposeStruct; + +KernelBase *CreateTranspose(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRANSPOSE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.c new file mode 100644 index 00000000..4b64aada --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.c @@ -0,0 +1,89 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/tril.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/triu_tril_fp32.h" + +int TrilCompute(KernelBase *self) { + TrilStruct *tril = (TrilStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(tril); + + int ret = TriuTrilGetKValue(self, &tril->k_); + if (ret != NNACL_OK) { + return ret; + } + + int64_t mul, height, width; + ret = TriuTrilGetCalculateNum(self, &mul, &height, &width); + if (ret != NNACL_OK) { + return ret; + } + + void *src_data = self->in_[FIRST_INPUT]->data_; + void *dst_data = self->out_[OUTPUT_INDEX]->data_; + int type_size = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_ZERO_RETURN_ERR(type_size); + + switch (type_size) { + case sizeof(int64_t): { + TrilByte8(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int32_t): { + TrilByte4(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int16_t): { + TrilByte2(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int8_t): { + TrilByte1(src_data, dst_data, tril->k_, height, width, mul); + break; + } + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +KernelBase *CreateTril(OpParameter *param, int data_type) { + TrilStruct *tril = (TrilStruct *)malloc(sizeof(TrilStruct)); + NNACL_CHECK_NULL_RETURN_NULL(tril); + tril->base_.Release = DefaultRelease; + tril->base_.Prepare = DefaultPrepare1In1Out; + tril->base_.Resize = DefaultResize; + tril->base_.Compute = TrilCompute; + return (KernelBase *)tril; +} + +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeDouble, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt8, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt8, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeBool, CreateTril) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.h new file mode 100644 index 00000000..189325b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/tril.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_TRIL_H_ +#define NNACL_KERNEL_TRIL_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct TrilStruct { + KernelBase base_; + int64_t k_; +} TrilStruct; + +KernelBase *CreateTril(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRIL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.c new file mode 100644 index 00000000..a3121f23 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.c @@ -0,0 +1,89 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/triu.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/fp32/triu_tril_fp32.h" + +int TriuCompute(KernelBase *self) { + TriuStruct *triu = (TriuStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(triu); + + void *src_data = self->in_[FIRST_INPUT]->data_; + void *dst_data = self->out_[OUTPUT_INDEX]->data_; + int type_size = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_ZERO_RETURN_ERR(type_size); + + int ret = TriuTrilGetKValue(self, &triu->k_); + if (ret != NNACL_OK) { + return ret; + } + + int64_t mul, height, width; + ret = TriuTrilGetCalculateNum(self, &mul, &height, &width); + if (ret != NNACL_OK) { + return ret; + } + + switch (type_size) { + case sizeof(int64_t): { + TriuByte8(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int32_t): { + TriuByte4(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int16_t): { + TriuByte2(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int8_t): { + TriuByte1(src_data, dst_data, triu->k_, height, width, mul); + break; + } + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +KernelBase *CreateTriu(OpParameter *param, int data_type) { + TriuStruct *triu = (TriuStruct *)malloc(sizeof(TriuStruct)); + NNACL_CHECK_NULL_RETURN_NULL(triu); + triu->base_.Release = DefaultRelease; + triu->base_.Prepare = DefaultPrepare1In1Out; + triu->base_.Resize = DefaultResize; + triu->base_.Compute = TriuCompute; + return (KernelBase *)triu; +} + +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeDouble, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt8, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt64, CreateTriu) 
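+/*
+ * Note on the registration fan-out above: TriuCompute dispatches on element
+ * byte width, so every 4-byte type (Float32, Int32, UInt32, ...) shares the
+ * TriuByte4 path. A usage sketch, assuming the ByteN routines zero elements
+ * strictly below the k-th diagonal (the buffers here are hypothetical and
+ * not part of this kernel):
+ *
+ *   float src[] = {1.0f, 2.0f, 3.0f, 4.0f};  // 2x2 row-major matrix
+ *   float dst[4];
+ *   TriuByte4(src, dst, 0, 2, 2, 1);  // k=0, height=2, width=2, mul=1
+ *   // dst == {1.0f, 2.0f, 0.0f, 4.0f}: entry (1,0) lies below the diagonal
+ */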
+REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt8, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeBool, CreateTriu) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.h new file mode 100644 index 00000000..e710cb76 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/triu.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_TRIU_H_ +#define NNACL_KERNEL_TRIU_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct TriuStruct { + KernelBase base_; + int64_t k_; +} TriuStruct; + +KernelBase *CreateTriu(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRIU_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.c new file mode 100644 index 00000000..ee9f8218 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.c @@ -0,0 +1,66 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/unique.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/fp32/unique_fp32.h" +#include "nnacl_c/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/unique_fp16.h" +#endif + +int UniqueCompute(KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output0 = self->out_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(output0); + TensorC *output1 = self->out_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(output1); + + int num = NNACLGetElementNum(input); + int output0_len = 0; + +#ifdef ENABLE_FP16 + if (input->data_type_ == kNumberTypeFloat16) { + UniqueFp16((float16_t *)input->data_, num, (float16_t *)output0->data_, &output0_len, (int *)output1->data_); + } +#endif + if (input->data_type_ == kNumberTypeInt32) { + UniqueInt((int *)input->data_, num, (int *)output0->data_, &output0_len, (int *)output1->data_); + } + if (input->data_type_ == kNumberTypeFloat32) { + Unique((float *)input->data_, num, (float *)output0->data_, &output0_len, (int *)output1->data_); + } + + output0->shape_changed_ = (output0->shape_[output0->shape_size_ - 1] != output0_len); + output0->shape_[output0->shape_size_ - 1] = output0_len; + return NNACL_OK; +} + +KernelBase *CreateUnique(OpParameter *param, int data_type) { + UniqueStruct *unique = (UniqueStruct *)malloc(sizeof(UniqueStruct)); + NNACL_CHECK_NULL_RETURN_NULL(unique); + unique->data_type_ = data_type; + unique->base_.Release = DefaultRelease; + unique->base_.Prepare = DefaultPrepare1In2Out; + unique->base_.Resize = DefaultResize; + unique->base_.Compute = UniqueCompute; + return (KernelBase *)unique; +} + +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeInt32, CreateUnique) +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat32, CreateUnique) +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat16, CreateUnique) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.h new file mode 100644 index 00000000..3083e925 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/unique.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_KERNEL_UNIQUE_H_ +#define NNACL_KERNEL_UNIQUE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct UniqueStruct { + KernelBase base_; + int data_type_; +} UniqueStruct; + +KernelBase *CreateUnique(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_UNIQUE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.c new file mode 100644 index 00000000..15229018 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.c @@ -0,0 +1,298 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/kernel/where.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/fp32/where_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl_c/fp16/where_fp16.h" +#endif +#include "nnacl_c/base/broadcast_to.h" + +int WhereExcuteFp16(WhereStruct *where, int task_id) { +#ifdef ENABLE_FP16 + WhereWithTripleInputsFp16((float16_t *)where->x_, (float16_t *)where->y_, (float16_t *)where->output_, &where->args_, + task_id, where->base_.thread_nr_); +#endif + return NNACL_OK; +} + +int WhereExcute(WhereStruct *where, int task_id) { + WhereWithTripleInputs((float *)where->x_, (float *)where->y_, (float *)where->output_, &where->args_, task_id, + where->base_.thread_nr_); + return NNACL_OK; +} + +int WhereRun(void *cdata, int task_id, float l, float r) { + WhereStruct *where = (WhereStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(where); + + NNACL_CHECK_NULL_RETURN_ERR(where->x_); + NNACL_CHECK_NULL_RETURN_ERR(where->y_); + NNACL_CHECK_NULL_RETURN_ERR(where->output_); + NNACL_CHECK_NULL_RETURN_ERR(where->args_.condition_); + + if (where->data_type_ == kNumberTypeFloat16) { + return WhereExcuteFp16(where, task_id); + } + return WhereExcute(where, task_id); +} + +int WhereRunWithSingleInput(WhereStruct *where) { + TensorC *input = where->base_.in_[FIRST_INPUT]; + int32_t *int32_condition = NULL; + float *fp32_condition = NULL; + bool *bool_condition = NULL; + switch (where->data_type_) { + case kNumberTypeInt32: + int32_condition = (int32_t *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(int32_condition); + break; + case kNumberTypeFloat32: + fp32_condition = (float *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(fp32_condition); + break; + case kNumberTypeBool: + bool_condition = (bool *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(bool_condition); + break; + default: + return NNACL_WHERE_CONDITION_DATA_TYPE_ERROR; + } + WhereArgs *where_args = &where->args_; + where_args->condition_num_ = NNACLGetElementNum(input); + where_args->rank_ = input->shape_size_; + int strides[MAX_SHAPE_SIZE]; + ComputeStrides(input->shape_, strides, where_args->rank_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(where_args->condition_num_, where_args->rank_, NNACL_ERR); + int data_num_int = where_args->condition_num_ * 
where_args->rank_;
+  NNACL_CHECK_TRUE_RET(data_num_int >= 0, NNACL_ERR);
+  size_t result_size = (size_t)data_num_int * sizeof(int32_t);
+  int32_t *result = where->base_.env_->Alloc(where->base_.env_->allocator_, result_size);
+  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(result);
+
+  int result_index = 0;
+  int true_num = 0;
+  for (int index = 0; index < where_args->condition_num_; index++) {
+    bool condition = false;
+    switch (where->data_type_) {
+      case kNumberTypeInt32:
+        condition = (bool)int32_condition[index];
+        break;
+      case kNumberTypeFloat32:
+        condition = (bool)fp32_condition[index];
+        break;
+      case kNumberTypeBool:
+        condition = (bool)bool_condition[index];
+        break;
+      default:
+        return NNACL_WHERE_CONDITION_DATA_TYPE_ERROR;
+    }
+    if (condition) {
+      true_num++;
+      int dim = index;
+      for (int j = 0; j < where_args->rank_; j++) {
+        NNACL_CHECK_ZERO_RETURN_ERR(strides[j]);
+        result[result_index++] = dim / strides[j];
+        dim %= strides[j];
+      }
+    }
+  }
+
+  TensorC *output = where->base_.out_[OUTPUT_INDEX];
+  if (output->data_ != NULL) {
+    /* output data should be NULL here; free any stale buffer before reshaping */
+    where->base_.env_->Free(where->base_.env_->allocator_, output->data_);
+  }
+  int output_shape[] = {true_num, where_args->rank_};
+  output->shape_changed_ = !ShapeEqual(output->shape_, output->shape_size_, output_shape, Num2);
+  output->shape_size_ = Num2;
+  memcpy(output->shape_, output_shape, Num2 * sizeof(int));
+
+  if (true_num > 0) {
+    output->data_ = result;
+  } else {
+    /* no true element: the index buffer is unused, release it to avoid a leak */
+    where->base_.env_->Free(where->base_.env_->allocator_, result);
+  }
+  return NNACL_OK;
+}
+
+int WhereBroadCastForInput(WhereStruct *where, TensorC *condition, TensorC *x, TensorC *y,
+                           void **condition_broadcast_buf, void **x_broadcast_buf, void **y_broadcast_buf,
+                           TensorC *output) {
+  size_t broad_cast_buf_size = NNACLGetElementNum(output);
+  if (output->data_type_ == kNumberTypeFloat32) {
+    broad_cast_buf_size *= sizeof(float);
+  } else {
+    return NNACL_WHERE_BROAD_CAST_FAILED;
+  }
+  BroadcastShapeInfo condition_info;
+  condition_info.input_shape_size_ = condition->shape_size_;
+  condition_info.output_shape_size_ = output->shape_size_;
+  memcpy(condition_info.input_shape_, condition->shape_, condition->shape_size_ * sizeof(int));
+  memcpy(condition_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
+
+  BroadcastShapeInfo x_info;
+  x_info.input_shape_size_ = x->shape_size_;
+  x_info.output_shape_size_ = output->shape_size_;
+  memcpy(x_info.input_shape_, x->shape_, x->shape_size_ * sizeof(int));
+  memcpy(x_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
+
+  BroadcastShapeInfo y_info;
+  y_info.input_shape_size_ = y->shape_size_;
+  y_info.output_shape_size_ = output->shape_size_;
+  memcpy(y_info.input_shape_, y->shape_, y->shape_size_ * sizeof(int));
+  memcpy(y_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
+
+  *condition_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size);
+  if (*condition_broadcast_buf == NULL) {
+    return NNACL_WHERE_BROAD_CAST_FAILED;
+  }
+  BroadcastToSize8(condition->data_, &condition_info, *condition_broadcast_buf);
+
+  *x_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size);
+  if (*x_broadcast_buf == NULL) {
+    where->base_.env_->Free(where->base_.env_->allocator_, *condition_broadcast_buf);
+    return NNACL_WHERE_BROAD_CAST_FAILED;
+  }
+  BroadcastToSize32(x->data_, &x_info, *x_broadcast_buf);
+
+  *y_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size);
+  if (*y_broadcast_buf == NULL) {
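+    /* the third allocation failed: release the condition and x buffers that
+       were already allocated before reporting the broadcast failure */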
where->base_.env_->Free(where->base_.env_->allocator_, *condition_broadcast_buf); + where->base_.env_->Free(where->base_.env_->allocator_, *x_broadcast_buf); + return NNACL_WHERE_BROAD_CAST_FAILED; + } + BroadcastToSize32(y->data_, &y_info, *y_broadcast_buf); + return NNACL_OK; +} + +int WhereRunWithTripleInputs(WhereStruct *where) { + TensorC *condition = where->base_.in_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(condition); + TensorC *x = where->base_.in_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(x); + TensorC *y = where->base_.in_[Index2]; + NNACL_CHECK_NULL_RETURN_ERR(y); + TensorC *output = where->base_.out_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int condition_nums = NNACLGetElementNum(condition); + int x_num = NNACLGetElementNum(x); + int y_num = NNACLGetElementNum(y); + int out_num = NNACLGetElementNum(output); + int num_max = condition_nums > x_num ? condition_nums : (x_num > y_num ? x_num : y_num); + + where->x_ = x->data_; + where->y_ = y->data_; + where->output_ = output->data_; + + WhereArgs *args = &where->args_; + args->condition_ = (bool *)condition->data_; + args->condition_num_ = condition_nums; + args->x_num_ = x_num; + args->y_num_ = y_num; + args->max_num_ = num_max; + + void *condition_broadcast_buf = NULL; + void *x_broadcast_buf = NULL; + void *y_broadcast_buf = NULL; + + if (out_num < num_max) { + return NNACL_WHERE_INVALID_OUT_NUM; + } + if (((condition_nums != 1) && (condition_nums != num_max)) || ((x_num != 1) && (x_num != num_max)) || + ((y_num != 1) && (y_num != num_max))) { + if (condition_nums != NNACLGetElementNum(y) && condition->shape_size_ != y->shape_size_) { + int ret = WhereBroadCastForInput(where, condition, x, y, &condition_broadcast_buf, &x_broadcast_buf, + &y_broadcast_buf, output); + if (ret != NNACL_OK) { + return NNACL_WHERE_BROAD_CAST_FAILED; + } + int max_num = NNACLGetElementNum(output); + args->condition_ = (bool *)condition_broadcast_buf; + where->x_ = x_broadcast_buf; + where->y_ = y_broadcast_buf; + where->output_ = output->data_; + args->condition_num_ = max_num; + args->x_num_ = max_num; + args->y_num_ = max_num; + args->max_num_ = max_num; + } else { + /* The length of three inputs are not equal to 1 or length of output, which is unacceptable */ + return NNACL_WHERE_CONDITION_NUM_INVALID; + } + } + if (num_max <= 0) { + /* Error, inputs' length are zero */ + return NNACL_WHERE_NUM_MAX_INVALID; + } + int ret = + where->base_.env_->ParallelLaunch(where->base_.env_->thread_pool_, WhereRun, where, where->base_.thread_nr_); + if (condition_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, condition_broadcast_buf); + condition_broadcast_buf = NULL; + } + if (x_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, x_broadcast_buf); + x_broadcast_buf = NULL; + } + if (y_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, y_broadcast_buf); + y_broadcast_buf = NULL; + } + return ret; +} + +int WhereCompute(KernelBase *self) { + WhereStruct *where = (WhereStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(where); + + int ret = NNACL_ERR; + if (self->in_size_ == Num1) { + ret = WhereRunWithSingleInput(where); + } else if (self->in_size_ == Num3) { + ret = WhereRunWithTripleInputs(where); + } else { + ret = NNACL_WHERE_INPUT_NUM_INVALID; + } + return ret; +} + +int WherePrepare(KernelBase *self) { + NNACL_CHECK_TRUE_RET(self->in_size_ == Num1 || self->in_size_ == Num3, NNACL_WHERE_INPUT_NUM_INVALID); + NNACL_CHECK_TRUE_RET(self->out_size_ == Num1, 
NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + return NNACL_OK; +} + +KernelBase *CreateWhere(OpParameter *param, int data_type) { + WhereStruct *where = (WhereStruct *)malloc(sizeof(WhereStruct)); + NNACL_CHECK_NULL_RETURN_NULL(where); + memset(where, 0, sizeof(WhereStruct)); + where->data_type_ = data_type; + where->base_.Prepare = WherePrepare; + where->base_.Compute = WhereCompute; + where->base_.Resize = DefaultResize; + where->base_.Release = DefaultRelease; + return (KernelBase *)where; +} + +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeBool, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeInt32, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeFloat16, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeFloat32, CreateWhere) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.h new file mode 100644 index 00000000..f2819969 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/where.h @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_WHERE_H_ +#define NNACL_KERNEL_WHERE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct WhereArgs { + int condition_num_; + int x_num_; + int y_num_; + int max_num_; + int rank_; + bool *condition_; +} WhereArgs; + +typedef struct WhereStruct { + KernelBase base_; + WhereArgs args_; + int data_type_; + void *x_; + void *y_; + void *output_; +} WhereStruct; + +KernelBase *CreateWhere(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_WHERE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.c new file mode 100644 index 00000000..946b92ea --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.c @@ -0,0 +1,43 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl_c/kernel/zeros_like.h" +#include "nnacl_c/kernel/default_kernel_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/tensor_c_utils.h" + +int ZerosLikeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + NNACL_CHECK_NULL_RETURN_ERR(output->data_); + (void)memset(output->data_, 0, NNACLGetSize(output)); + return NNACL_OK; +} + +KernelBase *CreateZerosLike(OpParameter *param, int data_type) { + ZerosLikeStruct *zeros_like = (ZerosLikeStruct *)malloc(sizeof(ZerosLikeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(zeros_like); + zeros_like->base_.Release = DefaultRelease; + zeros_like->base_.Prepare = DefaultPrepare1In1Out; + zeros_like->base_.Resize = DefaultResize; + zeros_like->base_.Compute = ZerosLikeCompute; + return (KernelBase *)zeros_like; +} + +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeInt32, CreateZerosLike) +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeFloat32, CreateZerosLike) +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeFloat16, CreateZerosLike) diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.h new file mode 100644 index 00000000..24085b3a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/kernel/zeros_like.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ZEROS_LIKE_H_ +#define NNACL_KERNEL_ZEROS_LIKE_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" +#include "nnacl_c/kernel.h" + +typedef struct ZerosLikeStruct { + KernelBase base_; +} ZerosLikeStruct; + +KernelBase *CreateZerosLike(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ZEROS_LIKE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/l2_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/l2_norm_parameter.h new file mode 100644 index 00000000..8476503c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/l2_norm_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_L2NORM_PARAMETER_H_ +#define NNACL_L2NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct L2NormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + int axis_[MAX_SHAPE_SIZE]; + // shape correlative + size_t axis_num_; + int data_num_; + int *shape_; + size_t shape_num_; + // other parameter + ActType act_type_; +} L2NormParameter; + +typedef struct { + QuantArg in_; + QuantArg out_; +} L2NormQuantArg; + +#endif // NNACL_L2NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/layer_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/layer_norm_parameter.h new file mode 100644 index 00000000..e2b7bc45 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/layer_norm_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_LAYER_NORM_PARAMETER_H_ +#define NNACL_LAYER_NORM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +typedef struct LayerNormParameter { + OpParameter op_parameter_; + float epsilon_; + bool elementwise_affine_; + int begin_norm_axis_; + int begin_params_axis_; +} LayerNormParameter; + +typedef struct LayerNormQuantArg { + int32_t in_zp_; + int32_t out_zp_; + double in_scale_; + double out_scale_; +} LayerNormQuantArg; + +#endif // NNACL_LAYER_NORM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/local_response_norm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/local_response_norm_parameter.h new file mode 100644 index 00000000..ebb5f6fb --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/local_response_norm_parameter.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_LRM_PARAMETER_H_ +#define NNACL_LRM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LocalResponseNormParameter { + OpParameter op_parameter_; + int depth_radius_; + float bias_; + float alpha_; + float beta_; +} LocalResponseNormParameter; + +#endif // NNACL_LRM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lsh_projection_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lsh_projection_parameter.h new file mode 100644 index 00000000..b3bb2e5b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lsh_projection_parameter.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_LSH_PROJECTION_PARAMETER_H_ +#define NNACL_LSH_PROJECTION_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LshProjectionParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int hash_shape_[2]; + // other parameter + int lsh_type_; + int feature_num_; + char **hash_buffs_; + size_t hash_buff_size_; + int64_t thread_stride_; +} LshProjectionParameter; + +#endif // NNACL_LSH_PROJECTION_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lstm_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lstm_parameter.h new file mode 100644 index 00000000..563b1c29 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/lstm_parameter.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_LSTM_PARAMETER_H_ +#define NNACL_LSTM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct LstmParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; + int project_size_; + int output_size_; + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + float zoneout_cell_; + float zoneout_hidden_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; + int proj_col_align_; + bool has_bias_; +} LstmParameter; + +#endif // NNACL_LSTM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/matmul_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/matmul_parameter.h new file mode 100644 index 00000000..e23edf89 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/matmul_parameter.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MATMUL_H_ +#define NNACL_MATMUL_H_ + +#include "nnacl_c/op_base.h" + +typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, + const int32_t *input_sum, const int32_t *bias); + +typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, + int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel); + +typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, + int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel, + const int32_t *filter_zp); + +typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2, OutType_NC4HW4 = 3 } OutType; + +typedef enum MatmulType { + // reserve 0 for base op + kNotImplemented = 0, + kMatmulInt8Cpu, + kMatmulDynamicInt8Cpu, + kMatmulDynamicSdotInt8Cpu, + kMatmulFp32BaseCpu, + kMatmulFp32Arm64Cpu, +} MatmulType; + +typedef struct MatMulParameter { + // Primitive parameter + OpParameter op_parameter_; + bool has_bias_; + bool use_axis_; + bool a_transpose_; /* false : row-major */ + bool b_transpose_; /* true : col-major */ + ActType act_type_; + + // other parameter + int row_; + int col_; + int row_4_; + int row_16_; + int row_align_; + int col_8_; + int col_align_; + int deep_; + int deep_4_; + int deep_16_; + int deep_align_; + int batch; + bool a_const_; + bool b_const_; + int axis_; + MatmulType matmul_type_; +} MatMulParameter; + +typedef struct MatmulQuantParameter { + QuantArg input_; + QuantArg weight_; + QuantArg output_; + int32_t out_act_min_; + int32_t out_act_max_; + float *filter_scale_; + int32_t 
*filter_zp_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; +} MatmulQuantParameter; + +typedef struct MatmulDynamicQuantParameter { + float *input_scale_; + int32_t *input_zp_; + float *filter_scale_; + int32_t *filter_zp_; +} MatmulDynamicQuantParameter; + +#endif // NNACL_MATMUL_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/mul_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/mul_parameter.h new file mode 100644 index 00000000..43cb7279 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/mul_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MUL_PARAMETER_H_ +#define NNACL_MUL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct MulQuantArg { + QuantArg in_quant_args_[2]; + QuantArg out_quant_arg_; + int output_multiplier_; + int output_activation_min_; + int output_activation_max_; + int shift_left_; + int shift_right_; +} MulQuantArg; + +#endif // NNACL_MUL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nllloss_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nllloss_parameter.h new file mode 100644 index 00000000..b112246f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nllloss_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_NLLLOSS_PARAMETER_H_ +#define NNACL_NLLLOSS_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct NLLLossParameter { + OpParameter op_parameter_; + ReductionType reduction_type_; +} NLLLossParameter; + +#endif // NNACL_NLLLOSS_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.c new file mode 100644 index 00000000..9cc3a1a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/nnacl_common.h" + +typedef union float32_bits { + unsigned int u; + float f; +} float32_bits; + +float ShortToFloat32(uint16_t src_value) { + const float32_bits magic = {113 << 23}; + const unsigned int shifted_exp = 0x7c00 << 13; + float32_bits o; + + o.u = (src_value & 0x7fff) << 13; + unsigned int exp = shifted_exp & o.u; + o.u += (127 - 15) << 23; + + if (exp == shifted_exp) { + o.u += (128 - 16) << 23; + } else if (exp == 0) { + o.u += 1 << 23; + o.f -= magic.f; + } + + o.u |= (src_value & 0x8000) << 16; + return o.f; +} + +uint16_t Float32ToShort(float src_value) { + float32_bits src_value_bits; + src_value_bits.f = src_value; + uint16_t res = 0; + // mantissa + res += (src_value_bits.u >> 13); + // exponent + res += (src_value_bits.u >> 13) & 0x3fc00; + res -= (127 - 15) << 13; + + // sign + res |= (src_value_bits.u & 0x80000000) >> 16; + return res; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.h new file mode 100644 index 00000000..df2a9a0a --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_common.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef NNACL_NNACL_COMMON_H_
+#define NNACL_NNACL_COMMON_H_
+
+#include <stddef.h>
+#include "nnacl_c/op_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+static inline size_t DataTypeCSize(TypeIdC type) {
+  switch (type) {
+    case kNumberTypeFloat64:
+      return sizeof(double);
+    case kNumberTypeFloat:
+    case kNumberTypeFloat32:
+      return sizeof(float);
+    case kNumberTypeInt8:
+      return sizeof(int8_t);
+    case kNumberTypeUInt8:
+      return sizeof(uint8_t);
+    case kNumberTypeFloat16:
+    case kNumberTypeInt16:
+      return sizeof(int16_t);
+    case kNumberTypeInt32:
+      return sizeof(int32_t);
+    case kNumberTypeInt64:
+      return sizeof(int64_t);
+    case kNumberTypeUInt16:
+      return sizeof(uint16_t);
+    case kNumberTypeUInt32:
+      return sizeof(uint32_t);
+    case kNumberTypeUInt64:
+      return sizeof(uint64_t);
+    case kNumberTypeComplex64:
+      return sizeof(float) + sizeof(float);
+    case kNumberTypeComplex128:
+      return sizeof(double) + sizeof(double);
+    case kNumberTypeBool:
+      return sizeof(bool);
+    case kObjectTypeString:
+      return sizeof(char);
+    case kObjectTypeTensorType:
+      return 0;
+    case kMetaTypeTypeType:
+      return sizeof(int);
+    default:
+      return 0;
+  }
+}
+
+static inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
+  int stride = 1;
+  for (int i = ndim - 1; i >= 0; i--) {
+    strides[i] = stride;
+    stride *= shape[i];
+  }
+}
+
+static inline void ComputeAxisDims(const int *shape, int shape_size, int axis, int *out_count, int *axis_count,
+                                   int *in_count) {
+  *out_count = 1;
+  *in_count = 1;
+  for (int i = 0; i < shape_size; i++) {
+    if (i < axis) {
+      *out_count = (*out_count) * shape[i];
+    }
+    if (i == axis) {
+      *axis_count = shape[axis];
+    }
+    if (i > axis) {
+      *in_count = (*in_count) * shape[i];
+    }
+  }
+}
+
+static const unsigned int FP32_BIT_SIZE = 32;
+static const unsigned int FP32_EXPONENT_BIAS = 127;
+static const unsigned int FP32_SIGNIFICAND = 23;
+static const unsigned int FP32_EXPONENT_MAX = 255;
+static const unsigned int FP16_BIT_SIZE = 16;
+static const unsigned int FP16_EXPONENT_BIAS = 15;
+static const unsigned int FP16_SIGNIFICAND = 10;
+static const int FP16_EXPONENT_MAX = 30;
+static const int FP16_EXPONENT_MIN = -10;
+static const int FP16_SHIFT = 13;
+float ShortToFloat32(uint16_t src_value);
+uint16_t Float32ToShort(float src_value);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_NNACL_COMMON_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c
new file mode 100644
index 00000000..20815a16
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl_c/nnacl_utils.h"
+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
+#include <sys/auxv.h>
+#endif
+
+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
+uint32_t getHwCap(int hwcap_type) {
+  uint32_t ret = getauxval(hwcap_type);
+  return ret;
+}
+#endif
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h
new file mode 100644
index 00000000..dfc19878
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_NNACL_UTILS_H_
+#define NNACL_NNACL_UTILS_H_
+
+#include <stdint.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#if defined(__arm__) || defined(__aarch64__)
+uint32_t getHwCap(int hwcap_type);
+#endif
+
+#ifdef DEBUG
+#include <assert.h>
+#define NNACL_ASSERT(f) assert(f)
+#else
+#define NNACL_ASSERT(f) ((void)0)
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // NNACL_NNACL_UTILS_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/non_max_suppression_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/non_max_suppression_parameter.h
new file mode 100644
index 00000000..06557138
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/non_max_suppression_parameter.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
+#define NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
+
+#include "nnacl_c/op_base.h"
+
+typedef struct NMSParameter {
+  // Primitive parameter
+  OpParameter op_parameter_;
+  int center_point_box_;
+} NMSParameter;
+
+#endif  // NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/one_hot_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/one_hot_parameter.h
new file mode 100644
index 00000000..b1650c36
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/one_hot_parameter.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NNACL_ONE_HOT_PARAMETER_H_
+#define NNACL_ONE_HOT_PARAMETER_H_
+#include "nnacl_c/op_base.h"
+
+typedef struct OneHotParameter {
+  OpParameter op_parameter_;
+  int axis_;
+} OneHotParameter;
+
+#endif  // NNACL_ONE_HOT_PARAMETER_H_
diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_base.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_base.h
new file mode 100644
index 00000000..e1a9c40c
--- /dev/null
+++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_base.h
@@ -0,0 +1,802 @@
+/**
+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <string.h>
+#include <stdio.h>
+#include <math.h>
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+
+#define C0NUM 0
+#define C1NUM 1
+#define C2NUM 2
+#define C3NUM 3
+#define C4NUM 4
+#define C5NUM 5
+#define C6NUM 6
+#define C7NUM 7
+#define C8NUM 8
+#define C9NUM 9
+#define C10NUM 10
+#define C11NUM 11
+#define C12NUM 12
+#define C13NUM 13
+#define C14NUM 14
+#define C15NUM 15
+#define C16NUM 16
+#define C17NUM 17
+#define C18NUM 18
+#define C19NUM 19
+#define C20NUM 20
+#define C21NUM 21
+#define C22NUM 22
+#define C23NUM 23
+#define C24NUM 24
+#define C28NUM 28
+#define C32NUM 32
+#define C36NUM 36
+#define C40NUM 40
+#define C44NUM 44
+#define C48NUM 48
+#define C56NUM 56
+#define C64NUM 64
+#define C128NUM 128
+#define C150NUM 150
+#define C256NUM 256
+#define C512NUM 512
+#define C1500NUM 1500
+#define TILE_NUM 8
+#define MAX_SPLIT_NUM 2048
+
+#define FP16_DATA_TYPE_LEN 2
+
+#ifndef MS_UNLIKELY
+#ifdef _MSC_VER
+#define MS_UNLIKELY(x) (x)
+#else
+#define MS_UNLIKELY(x) __builtin_expect(!!(x), 0)
+#endif
+#endif
+
+#ifndef MS_LIKELY
+#ifdef _MSC_VER
+#define MS_LIKELY(x) (x)
+#else
+#define MS_LIKELY(x) __builtin_expect(!!(x), 1)
+#endif
+#endif
+
+#define NNACL_MIN(x, y) ((x) < (y) ? (x) : (y))
+#define NNACL_MAX(x, y) ((x) > (y) ? (x) : (y))
+
+#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
+#define MSMAX(x, y) ((x) > (y) ? (x) : (y))
+#define MSCEIL(x) (int)((x) + (((x) - (int)(x)) > 0 ? 1 : 0))
+
+#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
+#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))
+#define DOWN_DIV(x, y) ((x) / (y))
+#define DOWN_ROUND(x, y) ((x) / (y) * (y))
+
+#define MSVALID(left, x, right) (MSMIN((MSMAX(left, x)), right))
+#define SIZE_MUL_OVERFLOW(x, y) (((x) == 0) ? false : (SIZE_MAX / (x)) < (y))
+#define INT_MUL_OVERFLOW(x, y) \
+  (((x) == 0) ? false \
+              : ((x) > 0 ? (((y) >= 0) ?
(INT_MAX / (x)) < (y) : (INT_MAX / (x)) < (-1 * (y))) \ + : (((y) >= 0) ? (INT_MAX / (x)) > (-1 * (y)) : (INT_MAX / (x)) > (y)))) + +#define INT_MUL_OVERFLOW_THRESHOLD(x, y, threshold) \ + (((x) == 0) ? false \ + : ((x) > 0 ? (((y) >= 0) ? ((threshold) / (x)) < (y) : ((threshold) / (x)) < (-1 * (y))) \ + : (((y) >= 0) ? ((threshold) / (x)) > (-1 * (y)) : ((threshold) / (x)) > (y)))) + +#define INT_ADD_OVERFLOW(x, y) (INT_MAX - (x)) < (y) + +#define INT_ADD_OVERFLOW_THRESHOLD(x, y, threshold) ((threshold) - (x)) < (y) + +#define MALLOC_MAX_SIZE (2000 * 1024 * 1024) + +#define COMM_SHAPE_SIZE 4 +#define MAX_SHAPE_SIZE 8 + +#define OUTPUT_INDEX 0 +#define FIRST_INPUT 0 +#define SECOND_INPUT 1 +#define THIRD_INPUT 2 +#define FOURTH_INPUT 3 +#define FIFTH_INPUT 4 +#define SIXTH_INPUT 5 +#define SEVENTH_INPUT 6 +#define EIGHTH_INPUT 7 +#define NINTH_INPUT 8 + +#define ONE_TENSOR 1 +#define TWO_TENSOR 2 +#define THREE_TENSOR 3 +#define FOUR_TENSOR 4 +#define FIVE_TENSOR 5 + +#define Index0 0 +#define Index1 1 +#define Index2 2 +#define Index3 3 +#define Index4 4 +#define Index5 5 +#define Index6 6 +#define Index7 7 +#define Index8 8 +#define Index9 9 + +#define Num0 0 +#define Num1 1 +#define Num2 2 +#define Num3 3 +#define Num4 4 +#define Num5 5 +#define Num6 6 +#define Num7 7 +#define Num8 8 +#define Num9 9 + +#define DIMENSION_0D 0 +#define DIMENSION_1D 1 +#define DIMENSION_2D 2 +#define DIMENSION_3D 3 +#define DIMENSION_4D 4 +#define DIMENSION_5D 5 +#define DIMENSION_6D 6 +#define DIMENSION_7D 7 +#define DIMENSION_8D 8 +#define DIMENSION_9D 9 +#define DIMENSION_10D 10 +#define DIMENSION_11D 11 +#define kInputIndex 0 +#define kWeightIndex 1 +#define kBiasIndex 2 +#define kOutputIndex 0 +#define kNHWC_N 0 +#define kNHWC_H 1 +#define kNHWC_W 2 +#define kNHWC_C 3 +#define kNCHW_N 0 +#define kNCHW_C 1 +#define kNCHW_H 2 +#define kNCHW_W 3 +#define kHWCN_C 2 +#define kHWNC_N 2 +#define kHWCN_N 3 +#define kNDHWC_N 0 +#define kNDHWC_D 1 +#define kNDHWC_H 2 +#define kNDHWC_W 3 +#define kNDHWC_C 4 +#define kInputSize1 2 +#define kInputSize2 3 +#define MAX_AXIS_SIZE 6 +#define MAX_LEN 256 +#define MAX_THREAD_NUM 64 +#define FLT16_MAX 65504 +#define kDefaulLiteMaxSpinCount 300000 +#define kDefaulLiteMinSpinCount 1 +#define kDefaulLiteIosSpinCount 1 +#define DEFAULT_GROUP_NAME_LEN 101 +#define kValueThreshold6 6 + +#define INVALID_SHAPE -1 + +#define CLARGSINDEX0 0 +#define CLARGSINDEX1 1 +#define CLARGSINDEX2 2 +#define CLARGSINDEX3 3 +#define CLARGSINDEX4 4 +#define CLARGSINDEX5 5 +#define CLARGSINDEX6 6 +#define CLARGSINDEX7 7 +#define CLARGSINDEX8 8 +#define CLARGSINDEX9 9 + +#define CLIDX_X 0 +#define CLIDX_Y 1 +#define CLIDX_Z 2 +#define CLIDX_W 3 + +#define RELU6_MIN_VAL 0 +#define RELU6_MAX_VAL 6 + +/* index for primitive_type & activation_type */ +#define TC_PTYPE(primitive_type) (primitive_type << 16) +#define TC_ATYPE(activation_type) (activation_type) +#define TC_TYPE(primitive_type, activation_type) (TC_PTYPE(primitive_type) + TC_ATYPE(activation_type)) + +#define NNACL_MALLOC_CHECK_NULL_RETURN_ERR(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + } while (0) + +#define NNACL_MALLOC_CHECK_NULL_RETURN_NULL(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NULL; \ + } \ + } while (0) + +#if ENABLE_HIGH_PERFORMANCE +#define NNACL_CHECK_TRUE_RET(value, errcode) +#define NNACL_CHECK_TRUE_RET_VOID(value) +#define NNACL_CHECK_FALSE(value, errcode) +#define NNACL_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) +#define NNACL_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, 
errcode) + +#define NNACL_CHECK_ZERO_RETURN_ERR(val) +#define NNACL_CHECK_ZERO_RETURN(val) +#define NNACL_CHECK_NULL_RETURN_ERR(ptr) +#define NNACL_CHECK_NULL_RETURN_VOID(ptr) +#define NNACL_CHECK_NULL_RETURN_NULL(ptr) +#define NNACL_CHECK_MALLOC_SIZE(val) +#else +#define NNACL_CHECK_TRUE_RET(value, errcode) \ + do { \ + if (!(value)) { \ + return errcode; \ + } \ + } while (0) + +#define NNACL_CHECK_TRUE_RET_VOID(value) \ + do { \ + if (!(value)) { \ + return; \ + } \ + } while (0) + +// Check whether value is false, if not return 'errcode' +#define NNACL_CHECK_FALSE(value, errcode) \ + do { \ + if ((value)) { \ + return errcode; \ + } \ + } while (0) + +#define NNACL_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) \ + NNACL_CHECK_TRUE_RET(!(INT_MUL_OVERFLOW(value1, value2)), errcode) +#define NNACL_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode) \ + NNACL_CHECK_TRUE_RET(!(INT_ADD_OVERFLOW(value1, value2)), errcode) +#define NNACL_CHECK_MALLOC_SIZE(malloc_size) \ + NNACL_CHECK_FALSE((malloc_size) > MALLOC_MAX_SIZE, NNACL_MALLOC_SIZE_INVALID) + +#define NNACL_CHECK_ZERO_RETURN_ERR(val) \ + do { \ + if ((val) == 0) { \ + return NNACL_ERR; \ + } \ + } while (0) + +#define NNACL_CHECK_ZERO_RETURN(val) \ + do { \ + if ((val) == 0) { \ + return; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_ERR(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_VOID(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_NULL(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NULL; \ + } \ + } while (0) +#endif + +enum PrimType { + PrimType_NONE = 0, + PrimType_Abs = 1, + PrimType_Activation = 2, + PrimType_ActivationGrad = 3, + PrimType_Adam = 4, + PrimType_AddFusion = 5, + PrimType_AdderFusion = 6, + PrimType_AddGrad = 7, + PrimType_AddN = 8, + PrimType_All = 9, + PrimType_ApplyMomentum = 10, + PrimType_ArgMaxFusion = 11, + PrimType_ArgMinFusion = 12, + PrimType_Assert = 13, + PrimType_Assign = 14, + PrimType_AssignAdd = 15, + PrimType_AudioSpectrogram = 16, + PrimType_AvgPoolFusion = 17, + PrimType_AvgPoolGrad = 18, + PrimType_BatchNorm = 19, + PrimType_BatchNormGrad = 20, + PrimType_BatchToSpace = 21, + PrimType_BatchToSpaceND = 22, + PrimType_BiasAdd = 23, + PrimType_BinaryCrossEntropy = 24, + PrimType_BinaryCrossEntropyGrad = 25, + PrimType_BiasAddGrad = 26, + PrimType_BroadcastTo = 27, + PrimType_Cast = 28, + PrimType_Ceil = 29, + PrimType_Clip = 30, + PrimType_Concat = 31, + PrimType_Attention = 32, + PrimType_Conv2DBackpropFilterFusion = 33, + PrimType_Conv2DBackpropInputFusion = 34, + PrimType_Conv2DFusion = 35, + PrimType_Conv2dTransposeFusion = 36, + PrimType_Cos = 37, + PrimType_ConstantOfShape = 38, + PrimType_Crop = 39, + PrimType_CustomExtractFeatures = 40, + PrimType_CustomNormalize = 41, + PrimType_CustomPredict = 42, + PrimType_DeConv2DGradFilter = 43, + PrimType_Depend = 44, + PrimType_DepthToSpace = 45, + PrimType_DetectionPostProcess = 46, + PrimType_DivFusion = 47, + PrimType_DivGrad = 48, + PrimType_Dropout = 49, + PrimType_DropoutGrad = 50, + PrimType_Elu = 51, + PrimType_Eltwise = 52, + PrimType_Equal = 53, + PrimType_EmbeddingLookupFusion = 54, + PrimType_ExpFusion = 55, + PrimType_ExpandDims = 56, + PrimType_FakeQuantWithMinMaxVars = 57, + PrimType_FakeQuantWithMinMaxVarsPerChannel = 58, + PrimType_FftReal = 59, + PrimType_FftImag = 60, + PrimType_Flatten = 61, + PrimType_FlattenGrad = 62, + PrimType_Floor = 63, + PrimType_FloorDiv = 64, + 
PrimType_FloorMod = 65, + PrimType_Fill = 66, + PrimType_FullConnection = 67, + PrimType_FusedBatchNorm = 68, + PrimType_Gather = 69, + PrimType_GatherNd = 70, + PrimType_Greater = 71, + PrimType_GreaterEqual = 72, + PrimType_HashtableLookup = 73, + PrimType_InstanceNorm = 74, + PrimType_LayerNormFusion = 75, + PrimType_LeakyRelu = 76, + PrimType_Less = 77, + PrimType_LessEqual = 78, + PrimType_Log = 79, + PrimType_LogGrad = 80, + PrimType_LogicalAnd = 81, + PrimType_LogicalNot = 82, + PrimType_LogicalOr = 83, + PrimType_LpNormalization = 84, + PrimType_LRN = 85, + PrimType_LshProjection = 86, + PrimType_LSTM = 87, + PrimType_L2NormalizeFusion = 88, + PrimType_MatMulFusion = 89, + PrimType_Maximum = 90, + PrimType_MaximumGrad = 91, + PrimType_MaxPoolFusion = 92, + PrimType_MaxPoolGrad = 93, + PrimType_SwitchLayer = 94, + PrimType_Mfcc = 95, + PrimType_Minimum = 96, + PrimType_MinimumGrad = 97, + PrimType_Mod = 98, + PrimType_MulFusion = 99, + PrimType_MulGrad = 100, + PrimType_Neg = 101, + PrimType_NegGrad = 102, + PrimType_NotEqual = 103, + PrimType_NonMaxSuppression = 104, + PrimType_OneHot = 105, + PrimType_OnesLike = 106, + PrimType_PadFusion = 107, + PrimType_PartialFusion = 108, + PrimType_PowerGrad = 109, + PrimType_PowFusion = 110, + PrimType_PriorBox = 111, + PrimType_PReLUFusion = 112, + PrimType_QuantDTypeCast = 113, + PrimType_Rank = 114, + PrimType_Range = 115, + PrimType_Reciprocal = 116, + PrimType_RealDiv = 117, + PrimType_ReduceFusion = 118, + PrimType_Reshape = 119, + PrimType_Resize = 120, + PrimType_ReverseSequence = 121, + PrimType_ReverseV2 = 122, + PrimType_Rfft = 123, + PrimType_ROIPooling = 124, + PrimType_Round = 125, + PrimType_Rsqrt = 126, + PrimType_ScaleFusion = 127, + PrimType_ScatterNd = 128, + PrimType_SGD = 129, + PrimType_Shape = 130, + PrimType_SigmoidCrossEntropyWithLogits = 131, + PrimType_SigmoidCrossEntropyWithLogitsGrad = 132, + PrimType_Sin = 133, + PrimType_SkipGram = 134, + PrimType_SliceFusion = 135, + PrimType_SmoothL1Loss = 136, + PrimType_SmoothL1LossGrad = 137, + PrimType_Softmax = 138, + PrimType_SoftmaxCrossEntropyWithLogits = 139, + PrimType_SpaceToBatch = 140, + PrimType_SpaceToBatchND = 141, + PrimType_SpaceToDepth = 142, + PrimType_SparseSoftmaxCrossEntropyWithLogits = 143, + PrimType_SparseToDense = 144, + PrimType_Split = 145, + PrimType_Sqrt = 146, + PrimType_Squeeze = 147, + PrimType_Square = 148, + PrimType_SquaredDifference = 149, + PrimType_Stack = 150, + PrimType_StridedSlice = 151, + PrimType_SubFusion = 152, + PrimType_SubGrad = 153, + PrimType_Switch = 154, + PrimType_TensorListFromTensor = 155, + PrimType_TensorListGetItem = 156, + PrimType_TensorListReserve = 157, + PrimType_TensorListSetItem = 158, + PrimType_TensorListStack = 159, + PrimType_TileFusion = 160, + PrimType_TopKFusion = 161, + PrimType_Transpose = 162, + PrimType_Unique = 163, + PrimType_UnsortedSegmentSum = 164, + PrimType_Unsqueeze = 165, + PrimType_Unstack = 166, + PrimType_LSTMGrad = 167, + PrimType_Where = 168, + PrimType_ZerosLike = 169, + PrimType_Select = 170, + PrimType_ScatterNdUpdate = 171, + PrimType_GRU = 172, + PrimType_NonZero = 173, + PrimType_InvertPermutation = 174, + PrimType_Size = 175, + PrimType_RandomStandardNormal = 176, + PrimType_CropAndResize = 177, + PrimType_Erf = 178, + PrimType_StridedSliceGrad = 179, + PrimType_IsFinite = 180, + PrimType_LinSpace = 181, + PrimType_UniformReal = 182, + PrimType_AbsGrad = 183, + PrimType_RsqrtGrad = 184, + PrimType_SqrtGrad = 185, + PrimType_LayerNormGrad = 186, + PrimType_ResizeGrad = 187, + 
PrimType_Splice = 188, + PrimType_LogSoftmax = 189, + PrimType_Call = 190, + PrimType_Custom = 191, + PrimType_CumSum = 192, + PrimType_SplitWithOverlap = 193, + PrimType_GenOP = 194, + PrimType_RaggedRange = 195, + PrimType_GLU = 196, + PrimType_TensorArray = 197, + PrimType_TensorArrayRead = 198, + PrimType_TensorArrayWrite = 199, + PrimType_Affine = 200, + PrimType_AllGather = 201, + PrimType_ReduceScatter = 202, + PrimType_DynamicQuant = 203, + PrimType_LSTMGradData = 204, + PrimType_LSTMGradWeight = 205, + PrimType_RandomNormal = 206, + PrimType_NLLLoss = 207, + PrimType_NLLLossGrad = 208, + PrimType_FormatTranspose = 209, + PrimType_GatherD = 210, + PrimType_GroupNormFusion = 211, + PrimType_Log1p = 212, + PrimType_TensorScatterAdd = 213, + PrimType_SparseFillEmptyRows = 214, + PrimType_SparseReshape = 215, + PrimType_SparseSegmentSum = 216, + PrimType_ScatterElements = 217, + PrimType_Triu = 218, + PrimType_Tril = 219, + PrimType_AdamWeightDecay = 220, + PrimType_FillV2 = 221, + PrimType_MIN = PrimType_NONE, + PrimType_MAX = PrimType_FillV2 + 1, + + // inner operators. + PrimType_Inner_ToFormat = 10000, + PrimType_Inner_GltextureToOpencl = 10001, + PrimType_Inner_Identity = 10002, + PrimType_Inner_ShapeFusion = 10003, + PrimType_Inner_GraphKernel = 10004, + PrimType_Inner_SplitReduceConcatFusion = 10005, + PrimType_Inner_EncoderLayer = 10006, + PrimType_Inner_FseDecode = 10007, + PrimType_Inner_DecoderLayer = 10008, + PrimType_Inner_UsePastEmbedding = 10009, + PrimType_Inner_CustomGru = 10010, + PrimType_Inner_CastGatherReduceFusion = 10011, + PrimType_Inner_ReduceConcatFusion = 10012, + PrimType_Inner_AclCustomOp = 10013, + PrimType_Inner_CustomMaskedFill = 10014, + PrimType_Inner_CustomTensorScatterMax = 10015, + PrimType_Inner_CustomIsInf = 10016, + PrimType_Inner_Conv3D = 10017, + PrimType_Inner_GridSampler = 10018, + PrimType_InnerOpMax, + PrimType_InnerOpMin = PrimType_Inner_ToFormat +}; + +typedef enum FormatC { + DEFAULT_FORMAT = -1, + Format_NCHW = 0, + Format_NHWC = 1, + Format_NHWC4 = 2, + Format_HWKC = 3, + Format_HWCK = 4, + Format_KCHW = 5, + Format_CKHW = 6, + Format_KHWC = 7, + Format_CHWK = 8, + Format_HW = 9, + Format_HW4 = 10, + Format_NC = 11, + Format_NC4 = 12, + Format_NC4HW4 = 13, + Format_NONE = 14, // The origin Format_NUM_OF_FORMAT can't be used. 
+ Format_NCDHW = 15, + Format_NWC = 16, + Format_NCW = 17, + Format_NDHWC = 18, + Format_NC8HW8 = 19, + Format_NC16HW16 = 20, + Format_MAX, + Format_MIN = Format_NCHW +} FormatC; + +typedef enum TypeIdC { + kTypeUnknown = 0, + kMetaTypeBegin = kTypeUnknown, + kMetaTypeType, // Type + kMetaTypeAny, + kMetaTypeObject, + kMetaTypeTypeType, // TypeType + kMetaTypeProblem, + kMetaTypeExternal, + kMetaTypeNone, + kMetaTypeNull, + kMetaTypeEllipsis, + kMetaTypeEnd, + // + // Object types + // + kObjectTypeBegin = kMetaTypeEnd, + kObjectTypeNumber, + kObjectTypeString, + kObjectTypeList, + kObjectTypeTuple, + kObjectTypeSlice, + kObjectTypeKeyword, + kObjectTypeTensorType, + kObjectTypeRowTensorType, + kObjectTypeCOOTensorType, + kObjectTypeUndeterminedType, + kObjectTypeClass, + kObjectTypeDictionary, + kObjectTypeFunction, + kObjectTypeJTagged, + kObjectTypeSymbolicKeyType, + kObjectTypeEnvType, + kObjectTypeRefKey, + kObjectTypeRef, + kObjectTypeEnd, + // + // Number Types + // + kNumberTypeBegin = kObjectTypeEnd, + kNumberTypeBool, + kNumberTypeInt, + kNumberTypeInt8, + kNumberTypeInt16, + kNumberTypeInt32, + kNumberTypeInt64, + kNumberTypeUInt, + kNumberTypeUInt8, + kNumberTypeUInt16, + kNumberTypeUInt32, + kNumberTypeUInt64, + kNumberTypeFloat, + kNumberTypeFloat16, + kNumberTypeFloat32, + kNumberTypeFloat64, + kNumberTypeDouble, + kNumberTypeComplex, + kNumberTypeComplex64, + kNumberTypeComplex128, + kNumberTypeInt4, + kNumberTypeGLUInt, + kNumberTypeEnd, +} TypeIdC; + +typedef enum DataOrder { + RowMajor, + ColMajor, +} DataOrder; + +typedef struct OpParameter { + char name_[100]; + int type_; + int thread_num_; + int quant_type_; + bool is_train_session_; + bool is_zero_shape_; + void (*destroy_func_)(struct OpParameter *param); +} OpParameter; + +typedef struct QuantArg { + float scale_; + int32_t zp_; +} QuantArg; + +typedef struct QuantMulArg { + int32_t multiplier_; + int left_shift_; + int right_shift_; +} QuantMulArg; + +typedef enum ReductionType { Reduction_Sum, Reduction_Mean, Reduction_None } ReductionType; +typedef enum ActType { + ActType_No = 0, + ActType_Relu = 1, + ActType_Sigmoid = 2, + ActType_Relu6 = 3, + ActType_Elu = 4, + ActType_LeakyRelu = 5, + ActType_Abs = 6, + ActType_Relu1 = 7, + ActType_Softsign = 8, + ActType_Softplus = 9, + ActType_Tanh = 10, + ActType_Selu = 11, + ActType_HSwish = 12, + ActType_HSigmoid = 13, + ActType_ThresholdRelu = 14, + ActType_Linear = 15, + ActType_HardTanh = 16, + ActType_Sign = 17, + ActType_Swish = 18, + ActType_Gelu = 19, + ActType_FastGelu = 20, + ActType_Unknown = 21 +} ActType; +typedef enum PadType { Pad_pad, Pad_same, Pad_valid } PadType; +typedef enum EltwiseType { Eltwise_PROD, Eltwise_SUM, Eltwise_MAXIMUM, Eltwise_UNKNOWN } EltwiseType; +typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode; + +typedef enum PaddingModeC { + PaddingMode_Constant, + PaddingMode_Reflect, + PaddingMode_Symmetric, + PaddingMode_Mode_Reserved, +} PaddingModeC; + +typedef enum ElementwiseModeC { + Elementwise_Not = 0, + Elementwise_Per_Channel = 1, + Elementwise_Per_Num = 2 +} ElementwiseModeC; + +typedef enum QuantTypeC { + Quant_None = 0, + Quant_AwareTraining = 1, + Quant_WeightQuant = 2, + Quant_PostTraining = 3, + Quant_QuantWeight = 4, + Quant_QuantAll = 5, + Quant_QuantDynamic = 6, + Quant_Min = Quant_None, + Quant_Max = Quant_QuantDynamic +} QuantTypeC; + +typedef enum TensorCategoryC { + VarTensor, // common tensor + ConstTensor, // const tensor + ConstScalar, // const scalar + GraphInput, + 
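/* GraphInput/GraphOutput: tensors bound to the graph's external inputs and outputs */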
GraphOutput +} TensorCategoryC; + +typedef enum ReduceModeC { + Reduce_Mean = 0, + Reduce_Max = 1, + Reduce_Min = 2, + Reduce_Prod = 3, + Reduce_Sum = 4, + Reduce_SumSquare = 5, + Reduce_ASum = 6, + Reduce_All = 7, + Reduce_L2 = 8, + Reduce_MIN = Reduce_Mean, + Reduce_MAX = Reduce_L2 +} ReduceModeC; + +typedef enum CalFixedMultiplierMode { + Method_No, + Method_SinglePrecision, + Method_DoublePrecision +} CalFixedMultiplierMode; + +#define VA_ARG_TUPLE_LEN 2 +static inline void offset_to_index_init(int offset, int cnt, ...) { + va_list valist; + va_start(valist, cnt); + int start = offset; + for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) { + int *x = va_arg(valist, int *); + int X = va_arg(valist, int); + + *x = start % X; + start = start / X; + } + va_end(valist); +} + +static inline void offset_to_index_step(int cnt, ...) { + va_list valist; + int flag = 1; + va_start(valist, cnt); + for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) { + int *x = va_arg(valist, int *); + int X = va_arg(valist, int); + if (flag) { + *x = (++*x != X) ? (flag = 0, *x) : (flag = 1, 0); + } + } + va_end(valist); +} + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_simd_header_file.h.in b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_simd_header_file.h.in new file mode 100644 index 00000000..316a11ce --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/op_simd_header_file.h.in @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_@OP_NAME_UPPER@_SIMD_H_ +#define NNACL_@OP_NAME_UPPER@_SIMD_H_ + +#include "nnacl_c/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX512 +#include "nnacl_c/avx512/@OP_NAME_LOWER@_avx512.h" +#endif + +#ifdef ENABLE_AVX +#include "nnacl_c/avx/@OP_NAME_LOWER@_avx.h" +#endif + +#ifdef ENABLE_SSE +#include "nnacl_c/sse/@OP_NAME_LOWER@_sse.h" +#endif + +#ifdef ENABLE_ARM +#include "nnacl_c/neon/@OP_NAME_LOWER@_neon.h" +#endif + +#endif diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/optimize/CMakeLists.txt b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/optimize/CMakeLists.txt new file mode 100644 index 00000000..9b190de7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/optimize/CMakeLists.txt @@ -0,0 +1,62 @@ +project(optimize) + +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) 
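# NNACL_DIR points at the nnacl_c root one level above this optimize/ subdirectory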
+include_directories(${NNACL_DIR}) + +########################### optimized files ########################### +file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c ${NNACL_DIR}/kernel/f16/*.c) +if(PLATFORM_ARM32) + file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S) +else() + file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) + file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) + set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) +endif() + +set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C) +set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C) + +if(APPLE) + set_source_files_properties(${SDOT_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") + set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") +endif() +########################### shared library build ######################## +list(APPEND FP16_FILES ${FP16_C_SRC}) +list(APPEND FP16_FILES ${FP16_NEON_SRC}) + +if(SUPPORT_TRAIN) + file(GLOB FP16_TRAIN_SRC ${NNACL_DIR}/fp16_grad/*.c) + list(APPEND FP16_FILES ${FP16_TRAIN_SRC}) +endif() +if(NOT MSVC) +string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +endif() +if(MACHINE_LINUX_ARM64) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+fp16") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+fp16") +elseif(NOT PLATFORM_ARM32 AND NOT TARGET_HIMIX AND (NOT (TARGET_AOS_ARM AND TOOLCHAIN_NAME STREQUAL "gcc"))) + list(APPEND SDOT_FILES ${SDOT_SRC}) + add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) + add_dependencies(nnacl_optimize_mid fbs_src) + if(NOT TARGET_MIX210) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") + endif() +endif() + +if(MSLITE_ENABLE_FP16) + add_library(nnacl_fp16_mid OBJECT ${FP16_FILES}) + add_dependencies(nnacl_fp16_mid fbs_src) + if(PLATFORM_ARM32) + target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp) + endif() + if(TARGET_AOS_ARM) + if(TOOLCHAIN_NAME STREQUAL "gcc") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+simd+fp16 -mtune=cortex-a72") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+simd+fp16 -mtune=cortex-a72") + else() + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+simd+dotprod+fp16 -mtune=cortex-a72") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+simd+dotprod+fp16 -mtune=cortex-a72") + endif() + endif() +endif() \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pack.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pack.h new file mode 100644 index 00000000..701c3cda --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pack.h @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_PACK_H_ +#define NNACL_PACK_H_ + +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/int8/pack_int8.h" + +#endif  // NNACL_PACK_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pad_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pad_parameter.h new file mode 100644 index 00000000..386e5c42 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pad_parameter.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_PAD_PARAMETER_H_ +#define NNACL_PAD_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +#define MAX_PAD_SIZE 12 +#define DEFAULT_PAD_NDIMS 6 + +typedef struct PadQuantArg { + QuantArg *in_quant_args_; + QuantArg *out_quanr_args_; + int8_t *constant_value_; +} PadQuantArg; + +typedef struct PadParameter { + OpParameter op_parameter_; + int paddings_[MAX_PAD_SIZE]; + int pad_mode_; + float constant_value_; + int padding_length; +} PadParameter; + +#endif  // NNACL_PAD_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/partial_fusion_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/partial_fusion_parameter.h new file mode 100644 index 00000000..85c2e69b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/partial_fusion_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_PARTIAL_FUSION_H_ +#define NNACL_PARTIAL_FUSION_H_ + +#include "nnacl_c/op_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/nnacl_utils.h" + +typedef struct PartialParameter { + OpParameter op_parameter_; + int sub_graph_index_; +} PartialParameter; + +#endif  // NNACL_PARTIAL_FUSION_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pooling_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pooling_parameter.h new file mode 100644 index 00000000..b2d962c2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pooling_parameter.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_POOLING_PARAMETER_H_ +#define NNACL_POOLING_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode; + +typedef enum RoundType { RoundType_No, RoundType_Ceil, RoundType_Floor } RoundType; + +typedef struct PoolingParameter { + OpParameter op_parameter_; + PoolMode pool_mode_; + RoundType round_type_; + PadType pad_mode_; + ActType act_type_; + int avg_mode_; + bool global_; + int window_w_; + int window_h_; + int stride_w_; + int stride_h_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; +} PoolingParameter; + +typedef struct Pooling3DParameter { + PoolingParameter pooling_parameter_; + int window_d_; + int stride_d_; + int input_d_; + int output_d_; + int pad_f_; // front + int pad_b_; // back + bool count_include_pad_; + int divisor_override_; +} Pooling3DParameter; + +#endif // NNACL_POOLING_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pow_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pow_parameter.h new file mode 100644 index 00000000..d62371a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/pow_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_POW_PARAMETER_H_ +#define NNACL_POW_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct PowQuantArg { + QuantArg in_args_; + QuantArg exp_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} PowQuantArg; + +typedef struct PowParameter { + OpParameter op_parameter_; + float power_; + float scale_; + float shift_; +} PowParameter; + +#endif // NNACL_POW_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/predict_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/predict_parameter.h new file mode 100644 index 00000000..cf86950e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/predict_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_PREDICT_PARAMETER_H_ +#define NNACL_PREDICT_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct { + // Primitive parameter + OpParameter op_parameter_; + // other parameter + int output_num; + float weight_threshold; +} PredictParameter; + +typedef struct { + int label; + float weight; +} LabelInfo; +#endif // NNACL_PREDICT_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prelu_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prelu_parameter.h new file mode 100644 index 00000000..c9c3b1d2 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prelu_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_PRELU_PARAMETER_H_ +#define NNACL_PRELU_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct PReluParameter { + OpParameter op_parameter_; + bool channel_shared_; +} PReluParameter; + +#endif // NNACL_PRELU_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prior_box_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prior_box_parameter.h new file mode 100644 index 00000000..ad835599 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/prior_box_parameter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_PRIOR_BOX_PARAMETER_H_ +#define NNACL_PRIOR_BOX_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct PriorBoxParameter { + OpParameter op_parameter_; + int32_t min_sizes_size; + int32_t min_sizes[MAX_SHAPE_SIZE]; + int32_t max_sizes_size; + int32_t max_sizes[MAX_SHAPE_SIZE]; + int32_t aspect_ratios_size; + float aspect_ratios[MAX_SHAPE_SIZE]; + float variances[COMM_SHAPE_SIZE]; + int32_t image_size_w; + int32_t image_size_h; + float step_w; + float step_h; + bool clip; + bool flip; + float offset; +} PriorBoxParameter; + +#endif  // NNACL_PRIOR_BOX_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/random_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/random_parameter.h new file mode 100644 index 00000000..3efeaf94 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/random_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_RANDOM_PARAMETER_H_ +#define NNACL_RANDOM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct RandomParam { + OpParameter op_parameter_; + int seed_; + int seed2_; +} RandomParam; + +typedef struct RandomNormalParam { + OpParameter op_parameter_; + float seed_; + float mean_; + float scale_; +} RandomNormalParam; + +#endif  // NNACL_RANDOM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/range_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/range_parameter.h new file mode 100644 index 00000000..64166dec --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/range_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef NNACL_RANGE_PARAMETER_H_ +#define NNACL_RANGE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct RangeParameter { + OpParameter op_parameter_; + int dtype_; + int start_; + int limit_; + int delta_; +} RangeParameter; + +#endif // NNACL_RANGE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_parameter.h new file mode 100644 index 00000000..40df590e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REDUCE_PARAMETER_H_ +#define NNACL_REDUCE_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct ReduceParameter { + OpParameter op_parameter_; + bool keep_dims_; + int mode_; + bool reduce_to_end_; + float coeff; +} ReduceParameter; + +#endif // NNACL_REDUCE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_scatter_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_scatter_parameter.h new file mode 100644 index 00000000..4ffe1b55 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reduce_scatter_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REDUCE_SCATTER_PARAMETER_H_ +#define NNACL_REDUCE_SCATTER_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ReduceScatterParameter { + // primitive parameter + OpParameter op_parameter_; + char group_[DEFAULT_GROUP_NAME_LEN]; + int mode_; + + // other parameter + int rank_size_; +} ReduceScatterParameter; +#endif // NNACL_REDUCE_SCATTER_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reshape_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reshape_parameter.h new file mode 100644 index 00000000..fd8bfa4c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reshape_parameter.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_RESHAPE_PARAMETER_H_ +#define NNACL_RESHAPE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ReshapeQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} ReshapeQuantArg; + +typedef struct ReshapeParameter { + // primitive parameter + OpParameter op_parameter_; + int shape_dim_; + int shape_[MAX_SHAPE_SIZE]; + + // other parameter + ReshapeQuantArg quant_para_; + int thread_count_; +} ReshapeParameter; + +#endif  // NNACL_RESHAPE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/resize_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/resize_parameter.h new file mode 100644 index 00000000..950821b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/resize_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_RESIZE_PARAMETER_H_ +#define NNACL_RESIZE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct ResizeParameter { + // primitive parameter + OpParameter op_parameter_; + int method_; + int64_t new_height_; + int64_t new_width_; + int coordinate_transform_mode_; + float cubic_coeff_; + bool preserve_aspect_ratio_; +} ResizeParameter; + +typedef struct CropAndResizeParameter { + // primitive parameter + OpParameter op_parameter_; + int method_; + float extrapolation_value_; +} CropAndResizeParameter; +#endif  // NNACL_RESIZE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_parameter.h new file mode 100644 index 00000000..dc7c02a0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef NNACL_REVERSE_PARAMETER_H_ +#define NNACL_REVERSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +#define REVERSE_SHAPE_MAX_SIZE 4 + +typedef struct ReverseParameter { + OpParameter op_parameter_; + int axis_[REVERSE_SHAPE_MAX_SIZE]; + int num_axis_; +} ReverseParameter; + +#endif // NNACL_REVERSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_sequence_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_sequence_parameter.h new file mode 100644 index 00000000..11e29906 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/reverse_sequence_parameter.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REVERSE_SEQUENCE_PARAMETER_H_ +#define NNACL_REVERSE_SEQUENCE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ReverseSequenceParameter { + // primitive parameter + OpParameter op_parameter_; + int seq_axis_; + int batch_axis_; + + // shape correlative + int input_shape0_[5]; + int output_shape_[5]; + int input_stride_[5]; + int output_stride_[5]; + + // other parameter + int ndim_; + int outer_count_; + int outer_stride_; + int inner_count_; + int inner_stride_; + int copy_byte_size_; + int total_data_size_; + bool is_seq_length_int32_; +} ReverseSequenceParameter; + +#endif // NNACL_REVERSE_SEQUENCE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scale_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scale_parameter.h new file mode 100644 index 00000000..65ad061c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scale_parameter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SCALE_H_ +#define NNACL_SCALE_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ScaleParameter { + OpParameter op_parameter_; + int axis_; + int activation_type_; +} ScaleParameter; + +typedef struct ScaleQuantParameter { + QuantMulArg scale_mul_arg_; + QuantMulArg offset_mul_arg_; + int input_zp_; + int scale_zp_; + int offset_zp_; + int output_zp_; + int output_activation_min_; + int output_activation_max_; +} ScaleQuantParameter; + +#endif // NNACL_SCALE_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_elements_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_elements_parameter.h new file mode 100644 index 00000000..b1547dac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_elements_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_SCATTER_ELEMENTS_PARAMETER_H_ +#define NNACL_SCATTER_ELEMENTS_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct ScatterElementsParameter { + OpParameter op_parameter_; + int axis_; +} ScatterElementsParameter; + +#endif // NNACL_SCATTER_ELEMENTS_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_nd_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_nd_parameter.h new file mode 100644 index 00000000..6a70ad86 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/scatter_nd_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SCATTER_ND_PARAMETER_H_ +#define NNACL_SCATTER_ND_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct ScatterNDParameter { + OpParameter op_parameter; + int num_unit; + int unit_size; + int data_type_len; +} ScatterNDParameter; + +#endif // NNACL_SCATTER_ND_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sequence_unstack_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sequence_unstack_parameter.h new file mode 100644 index 00000000..14ac9d56 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sequence_unstack_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ +#define MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SequenceUnstackParameter { + // primitive parameter + OpParameter op_parameter_; + int num_; + int axis_; + + // other parameter + int pre_dims_; + int axis_dim_; + int after_dims_; +} SequenceUnstackParameter; + +#endif // MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sigmoid_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sigmoid_parameter.h new file mode 100644 index 00000000..d864ef5d --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sigmoid_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SIGMOID_PARAMETER_H_ +#define NNACL_SIGMOID_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SigmoidParameter { + // primitive parameter + OpParameter op_parameter_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + + // other parameter + SigmoidQuantArg quant_arg; + double alpha_; + int thread_count_; + int64_t offset_[MAX_SHAPE_SIZE]; + int64_t in_offset_[MAX_SHAPE_SIZE]; + int64_t axis_; + int input_dim_; + int element_num; +} SigmoidParameter; + +#endif // NNACL_SIGMOID_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/skip_gram_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/skip_gram_parameter.h new file mode 100644 index 00000000..46bb751f --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/skip_gram_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SKIP_GRAM_PARAMETER_H_ +#define NNACL_SKIP_GRAM_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SkipGramParameter { + // primitive parameter + OpParameter op_parameter_; + bool include_all_ngrams; + int max_skip_size; + int ngram_size; +} SkipGramParameter; + +#endif // NNACL_SKIP_GRAM_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/slice_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/slice_parameter.h new file mode 100644 index 00000000..ca5af2fa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/slice_parameter.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SLICE_PARAMETER_H_ +#define NNACL_SLICE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SliceQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + QuantMulArg multiplier_; +} SliceQuantArg; + +typedef struct SliceParameter { + OpParameter op_parameter_; + int32_t axis_[DIMENSION_8D]; +} SliceParameter; + +#endif // NNACL_SLICE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/softmax_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/softmax_parameter.h new file mode 100644 index 00000000..428e0673 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/softmax_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SOFTMAX_PARAMETER_H_ +#define NNACL_SOFTMAX_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SoftmaxParameter { + OpParameter op_parameter_; + int32_t axis_; +} SoftmaxParameter; + +#endif // NNACL_SOFTMAX_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/space_to_depth_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/space_to_depth_parameter.h new file mode 100644 index 00000000..fa337a8c --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/space_to_depth_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ +#define LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct SpaceToDepthParameter { + // primitive parameter + OpParameter op_parameter_; + int32_t block_size_; + int32_t date_type_len; +} SpaceToDepthParameter; + +#endif // LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sparse_to_dense_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sparse_to_dense_parameter.h new file mode 100644 index 00000000..2aa82ea0 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/sparse_to_dense_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPARSE_TO_DENSE_PARAMETER_H_ +#define NNACL_SPARSE_TO_DENSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct SparseToDenseParameter { + // primitive parameter + OpParameter op_parameter_; + bool validate_indices_; + bool is_scalar; + int index_num; + int output_num; + int output_stride[DIMENSION_4D]; +} SparseToDenseParameter; + +#endif // NNACL_SPARSE_TO_DENSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/splice_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/splice_parameter.h new file mode 100644 index 00000000..1bb1b7b7 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/splice_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SPLICE_PARAMETER_H_ +#define NNACL_SPLICE_PARAMETER_H_ +#include "nnacl_c/op_base.h" +typedef struct SpliceParameter { + OpParameter op_parameter_; + int context_dim_; + int forward_indexes_dim_; + int *context_; + int *forward_indexes_; + int output_dim_; +} SpliceParameter; +#endif // NNACL_SPLICE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/split_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/split_parameter.h new file mode 100644 index 00000000..5365daac --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/split_parameter.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPLIT_PARAMETER_H_ +#define NNACL_SPLIT_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +#define SPLIT_STRIDES_SIZE 32 +#define SPLIT_MAX_SLICE_NUM 10 + +typedef struct SplitQuantArg { + QuantArg in_args_; + QuantArg out_args_[20]; + int output_activation_min_; + int output_activation_max_; +} SplitQuantArg; + +typedef struct SplitParameter { + // primitive parameter + OpParameter op_parameter_; + int num_split_; + int *split_sizes_; + int split_dim_; + + // shape correlative + int strides_[SPLIT_STRIDES_SIZE]; + + // other parameter + SplitQuantArg quant_arg_; + int n_dims_; + int split_count_; +} SplitParameter; + +typedef struct SplitWithOverlapParameter { + OpParameter op_parameter_; + int num_split_; + int split_dim_; + int ratio_[SPLIT_MAX_SLICE_NUM]; + int extend_top_[SPLIT_MAX_SLICE_NUM]; + int extend_bottom_[SPLIT_MAX_SLICE_NUM]; + + // other parameter + int element_bytes_; + int split_dim_size_; + int outer_total_dim_; + int inner_stride_; +} SplitWithOverlapParameter; + +#endif // NNACL_SPLIT_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/squeeze_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/squeeze_parameter.h new file mode 100644 index 00000000..cfbaad63 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/squeeze_parameter.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_SQUEEZE_PARAMETER_H_ +#define NNACL_SQUEEZE_PARAMETER_H_ +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quantize.h" + +#define SQUEEZE_OFFSET_MAX_SIZE 4 + +typedef struct SqueezeQuantArg { + QuantArg *in_quant_args_; + QuantArg *out_quant_args_; +} SqueezeQuantArg; + +typedef struct SqueezeParameter { + // primitive parameter + OpParameter op_parameter_; + int axis_[8]; + size_t axis_size_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + int offset_size_; + int64_t offset_[SQUEEZE_OFFSET_MAX_SIZE]; + int64_t in_offset_[SQUEEZE_OFFSET_MAX_SIZE]; + int input_dim_; + // other parameter + SqueezeQuantArg quant_arg; +} SqueezeParameter; + +#endif // NNACL_SQUEEZE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/stack_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/stack_parameter.h new file mode 100644 index 00000000..55d66a51 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/stack_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_STACK_PARAMETER_H_ +#define NNACL_STACK_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct StackParameter { + // primitive parameter + OpParameter op_parameter_; + int32_t axis_; +} StackParameter; + +#endif // NNACL_STACK_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/strided_slice_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/strided_slice_parameter.h new file mode 100644 index 00000000..3ff8618b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/strided_slice_parameter.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_STRIDED_SLICE_PARAMETER_H_ +#define NNACL_STRIDED_SLICE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct StridedSliceParameter { + // primitive parameter + OpParameter op_parameter_; + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int isScale; + + // shape correlative + int in_shape_length_; + int in_shape_[MAX_SHAPE_SIZE]; + + // other parameter + int num_axes_; + TypeIdC data_type; + int begins_mask_; + int ends_mask_; + int ellipsisMask_; + int newAxisMask_; + int shrinkAxisMask_; +} StridedSliceParameter; + +#endif // NNACL_STRIDED_SLICE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_array_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_array_parameter.h new file mode 100644 index 00000000..aca72845 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_array_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TENSOR_ARRAY_PARAMETER_H_ +#define NNACL_TENSOR_ARRAY_PARAMETER_H_ +#include "nnacl_c/op_base.h" + +typedef struct TensorArrayParameter { + OpParameter op_parameter_; + bool dynamic_size_; + bool identical_element_shapes_; + int element_shape_[MAX_SHAPE_SIZE]; + int element_shape_size_; + int data_type_; +} TensorArrayParameter; + +#endif // NNACL_TENSOR_ARRAY_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c.h new file mode 100644 index 00000000..6d515a28 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TENSOR_C_H_ +#define NNACL_TENSOR_C_H_ +#include "nnacl_c/op_base.h" + +typedef struct TensorC { + bool shape_changed_; + int data_type_; + int format_; + int category_; + void *data_; + size_t shape_size_; + int shape_[MAX_SHAPE_SIZE]; + char *name_; // only used in micro now. 
+} TensorC; + +#endif // NNACL_TENSOR_C_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c new file mode 100644 index 00000000..2a2564e8 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c @@ -0,0 +1,440 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/tensor_c_utils.h" +#include "nnacl_c/nnacl_common.h" + +int CheckAugmentNull(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter) { + NNACL_CHECK_NULL_RETURN_ERR(inputs); + NNACL_CHECK_NULL_RETURN_ERR(outputs); + for (size_t i = 0; i < inputs_size; i++) { + if (inputs[i] == NULL) { + return NNACL_NULL_PTR; + } + } + for (size_t i = 0; i < outputs_size; i++) { + if (outputs[i] == NULL) { + return NNACL_NULL_PTR; + } + } + if (parameter == NULL) { + return NNACL_NULL_PTR; + } + return NNACL_OK; +} + +int CheckAugmentNullSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (inputs_size != inputs_size_obj || outputs_size != outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentNullSizeInputTwo(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, const OpParameter *parameter, size_t inputs_size_obj_0, + size_t inputs_size_obj_1, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if ((inputs_size != inputs_size_obj_0 && inputs_size != inputs_size_obj_1) || outputs_size != outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentNullInputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (inputs_size != inputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (outputs_size != outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentWithMinSize(const TensorC *const
*inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (inputs_size < inputs_size_obj || outputs_size < outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +void SetShapeTensor(TensorC *dst, const TensorC *src) { + for (size_t i = 0; i < src->shape_size_; i++) { + dst->shape_[i] = src->shape_[i]; + } + dst->shape_size_ = src->shape_size_; +} + +void SetShapeArray(TensorC *dst, const int *src, size_t src_size) { + for (size_t i = 0; i < src_size && i < MAX_SHAPE_SIZE; i++) { + dst->shape_[i] = src[i]; + } + dst->shape_size_ = src_size; +} + +void SetDataTypeFormat(TensorC *dst, const TensorC *src) { + dst->format_ = src->format_; + dst->data_type_ = src->data_type_; +} + +int NNACLGetBatch(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NHWC: + case Format_NHWC4: + case Format_NCHW: + case Format_NC4HW4: + case Format_NC8HW8: + case Format_KCHW: + case Format_KHWC: + case Format_NC: + case Format_NC4: + return tensor->shape_[kNHWC_N]; + case Format_HWCK: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWCN_N]; + case Format_HWKC: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWNC_N]; + case Format_CKHW: + return tensor->shape_[1]; + default: + return -1; + } +} +int NNACLGetHeight(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNCHW_H]; + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + case Format_CHWK: + return tensor->shape_[kNHWC_H]; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + return tensor->shape_[0]; + default: + return -1; + } +} +int NNACLGetWidth(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNCHW_W]; + case Format_KHWC: + case Format_NHWC: + case Format_NHWC4: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNHWC_W]; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + return tensor->shape_[1]; + default: + return -1; + } +} +int NNACLGetChannel(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_NC: + case Format_NC4: + case Format_NC4HW4: + case Format_NC8HW8: + return tensor->shape_[kNCHW_C]; + case Format_HWCK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWCN_C]; + case Format_HWKC: + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + if 
(tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNHWC_C]; + case Format_CKHW: + case Format_CHWK: + return tensor->shape_[0]; + default: + return -1; + } +} + +void NNACLSetBatch(TensorC *tensor, int batch) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NHWC: + case Format_NHWC4: + case Format_NCHW: + case Format_NC4HW4: + case Format_NC8HW8: + case Format_KCHW: + case Format_KHWC: + case Format_NC: + case Format_NC4: + tensor->shape_[kNHWC_N] = batch; + return; + case Format_HWCK: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWCN_N] = batch; + return; + case Format_HWKC: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWNC_N] = batch; + return; + case Format_CKHW: + tensor->shape_[1] = batch; + return; + default: + return; + } +} + +void NNACLSetHeight(TensorC *tensor, int height) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNCHW_H] = height; + return; + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + case Format_CHWK: + tensor->shape_[kNHWC_H] = height; + return; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + tensor->shape_[0] = height; + return; + default: + return; + } +} + +void NNACLSetWidth(TensorC *tensor, int width) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNCHW_W] = width; + return; + case Format_KHWC: + case Format_NHWC: + case Format_NHWC4: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNHWC_W] = width; + return; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + tensor->shape_[1] = width; + return; + default: + return; + } +} + +void NNACLSetChannel(TensorC *tensor, int channel) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_NC: + case Format_NC4: + case Format_NC4HW4: + case Format_NC8HW8: + tensor->shape_[kNCHW_C] = channel; + return; + case Format_HWCK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWCN_C] = channel; + return; + case Format_HWKC: + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNHWC_C] = channel; + return; + case Format_CKHW: + case Format_CHWK: + tensor->shape_[0] = channel; + return; + default: + return; + } +} + +int NNACLGetSize(const TensorC *tensor) { + int element_num = NNACLGetElementNum(tensor); + int data_type_size = (int)DataTypeCSize(tensor->data_type_); + return element_num * data_type_size; +} + +int NNACLGetElementNum(const TensorC *tensor) { + if (tensor == NULL) { + return -1; + } + if (tensor->shape_size_ == 0) { + return 1; // scalar mode + } + int res = 1; + for (size_t i = 0; i < tensor->shape_size_; i++) { + 
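// guard each multiply so a corrupt shape cannot silently overflow the int element count +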
NNACL_CHECK_INT_MUL_NOT_OVERFLOW(res, tensor->shape_[i], NNACL_ERRCODE_MUL_OVERFLOW); + res = res * tensor->shape_[i]; + } + + int c = NNACLGetChannel(tensor); + if (c == 0) { + return res; + } + if (tensor->format_ == Format_NC4HW4) { + res = res / c * UP_ROUND(c, C4NUM); + } + if (tensor->format_ == Format_NC8HW8) { + res = res / c * UP_ROUND(c, C8NUM); + } + return res; +} + +int NNACLGetDimensionSize(const TensorC *tensor, const size_t index) { + int dim_size = -1; + if (index < tensor->shape_size_) { + dim_size = tensor->shape_[index]; + } + return dim_size; +} + +bool NNACLIsShapeSame(const TensorC *tensor1, const TensorC *tensor2) { + if (tensor1->shape_size_ != tensor2->shape_size_) { + return false; + } + for (size_t i = 0; i < tensor1->shape_size_; i++) { + if (tensor1->shape_[i] != tensor2->shape_[i]) { + return false; + } + } + return true; +} + +bool NNACLIsConst(const TensorC *tensor) { + return (tensor->category_ == ConstTensor || tensor->category_ == ConstScalar) && tensor->data_ != NULL; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h new file mode 100644 index 00000000..29a8112e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TENSORC_UTILS_H_ +#define NNACL_TENSORC_UTILS_H_ + +#include +#include "nnacl_c/errorcode.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensor_c.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NNACLGetBatch(const TensorC *tensor); +int NNACLGetHeight(const TensorC *tensor); +int NNACLGetWidth(const TensorC *tensor); +int NNACLGetChannel(const TensorC *tensor); +void NNACLSetBatch(TensorC *tensor, int batch); +void NNACLSetHeight(TensorC *tensor, int height); +void NNACLSetWidth(TensorC *tensor, int width); +void NNACLSetChannel(TensorC *tensor, int channel); +int NNACLGetElementNum(const TensorC *tensor); +int NNACLGetSize(const TensorC *tensor); +int NNACLGetDimensionSize(const TensorC *tensor, const size_t index); +bool NNACLIsShapeSame(const TensorC *tensor1, const TensorC *tensor2); +bool NNACLIsConst(const TensorC *tensor); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_TENSORC_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h new file mode 100644 index 00000000..d0839275 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TENSORLIST_C_H_ +#define NNACL_TENSORLIST_C_H_ + +#include "nnacl_c/tensor_c.h" + +typedef struct vvector { + int **shape_; // shape values + int *shape_size_; // length of each shape + size_t size_; // number of shapes +} vvector; + +typedef struct TensorListC { + bool shape_changed_; + int data_type_; + int format_; + int shape_value_; + int tensors_data_type_; // element_data_type_, kept in sync with the C++ side + int max_elements_num_; + TensorC **tensors_; + size_t element_num_; + size_t element_shape_size_; + int element_shape_[MAX_SHAPE_SIZE]; +} TensorListC; + +#endif // NNACL_TENSORLIST_C_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c new file mode 100644 index 00000000..4a6122c3 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c @@ -0,0 +1,89 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl_c/tensorlist_c_utils.h" + +int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, const vvector *tensor_shape) { + // This function creates a new tensors_ array. + // You must set the shape (param 2: tensor_shape) and data_type_ (tensors_data_type_ = param 1: dtype) of each + // tensor in tensors_. After that, call MallocData to allocate the data buffer of each tensor in tensors_.
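+  // A minimal usage sketch (illustrative only; `list` and the other names below are hypothetical, +  // and element_num_ must already equal tensor_shape->size_): +  //   int dims[] = {2, 3}; +  //   int *shapes[] = {dims}; +  //   int dim_counts[] = {2}; +  //   vvector shape = {shapes, dim_counts, 1}; +  //   MallocTensorListData(list, kNumberTypeFloat32, &shape);  // assumes list->element_num_ == 1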
+ + if (tensor_list->element_num_ == 0) { + return NNACL_OK; + } + if (((size_t)(tensor_list->element_num_)) != tensor_shape->size_) { + return NNACL_ERR; + } + tensor_list->tensors_data_type_ = dtype; + void *addr = malloc(tensor_list->element_num_ * sizeof(void *) + + tensor_list->element_num_ * sizeof(TensorC)); // free in infer_manager + if (addr == NULL) { + free(tensor_list->tensors_); + return NNACL_NULL_PTR; + } + memset(addr, 0, tensor_list->element_num_ * sizeof(void *) + tensor_list->element_num_ * sizeof(TensorC)); + tensor_list->tensors_ = (TensorC **)addr; + TensorC *tensors = (TensorC *)(tensor_list->tensors_ + tensor_list->element_num_); + for (size_t i = 0; i < tensor_list->element_num_; ++i) { + TensorC *tensor = tensors + i; + tensor_list->tensors_[i] = tensor; + tensor->format_ = Format_NHWC; + tensor->data_type_ = dtype; + ShapeSet(tensor->shape_, &(tensor->shape_size_), tensor_shape->shape_[i], (size_t)tensor_shape->shape_size_[i]); + } + return NNACL_OK; +} + +int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size) { + if (*element_shape_size >= 255 || element_shape[0] == -1) { + ShapeSet(element_shape, element_shape_size, tmp, tmp_size); + return NNACL_OK; + } + if (*element_shape_size != tmp_size) { + return NNACL_ERR; + } + for (size_t j = 0; j < tmp_size; ++j) { + if (element_shape[j] >= 0 && tmp[j] >= 0 && element_shape[j] != tmp[j]) { + return NNACL_ERR; + } + element_shape[j] = element_shape[j] >= 0 ? element_shape[j] : tmp[j]; + } + return NNACL_OK; +} + +bool TensorListIsFullyDefined(const int *shape, size_t shape_size) { + for (size_t i = 0; i < shape_size; ++i) { + if (shape[i] < 0) { + return false; + } + } + return true; +} + +bool InferFlagTensorList(TensorC *tensorc) { + TensorListC *input_tensor_list = (TensorListC *)tensorc; + if (input_tensor_list->shape_value_ == -1) { + return false; + } + return true; +} diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h new file mode 100644 index 00000000..a69ecc12 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h @@ -0,0 +1,38 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_TENSORLIST_C_UTILS_H_ +#define NNACL_TENSORLIST_C_UTILS_H_ + +#include +#include "nnacl_c/op_base.h" +#include "nnacl_c/tensorlist_c.h" +#include "nnacl_c/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, const vvector *tensor_shape); +int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size); +bool TensorListIsFullyDefined(const int *shape, size_t shape_size); +bool InferFlagTensorList(TensorC *tensor_list); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_TENSORLIST_C_UTILS_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_parameter.h new file mode 100644 index 00000000..beb70c1b --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tensorlist_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TENSORLIST_PARAMETER_H_ +#define NNACL_TENSORLIST_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct TensorListParameter { + // primitive parameter + OpParameter op_parameter_; + int shape_type_; + int element_dtype_; + + // other parameter + int num_element_; +} TensorListParameter; + +#endif // NNACL_TENSORLIST_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tile_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tile_parameter.h new file mode 100644 index 00000000..d7ad9ef9 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/tile_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#ifndef NNACL_TILE_PARAMETER_H_ +#define NNACL_TILE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct TileParameter { + OpParameter op_parameter_; + size_t dims_size_; + int dims_[MAX_SHAPE_SIZE]; + int multiples_[MAX_SHAPE_SIZE]; +} TileParameter; + +#endif // NNACL_TILE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/transpose_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/transpose_parameter.h new file mode 100644 index 00000000..34a78453 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/transpose_parameter.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TRANSPOSE_PARAMETER_H_ +#define NNACL_TRANSPOSE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +// MAX_TRANSPOSE_SERIAL_SIZE = 64 * 3 * 512 * 512 +#define MAX_TRANSPOSE_SERIAL_SIZE 50331648 +#define MAX_TRANSPOSE_DIM_SIZE 20 +#define PERM_NUM_THREE 3 +#define PERM_NUM_FOUR 4 + +typedef struct TransposeParameter { + // primitive parameter + OpParameter op_parameter_; + int perm_[MAX_TRANSPOSE_DIM_SIZE]; + size_t perm_size_; + bool conjugate_; + + // shape correlative + int strides_[MAX_TRANSPOSE_DIM_SIZE]; + int out_strides_[MAX_TRANSPOSE_DIM_SIZE]; + + // other parameter + int num_axes_; + int data_num_; +} TransposeParameter; + +#endif // NNACL_TRANSPOSE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/triu_tril_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/triu_tril_parameter.h new file mode 100644 index 00000000..b9b5a95e --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/triu_tril_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_TRIU_TRIL_PARAMETER_H_ +#define NNACL_TRIU_TRIL_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct TriuParameter { + // Primitive parameter + OpParameter op_parameter_; +} TriuParameter; + +typedef struct TrilParameter { + // Primitive parameter + OpParameter op_parameter_; +} TrilParameter; + +#endif // NNACL_TRIU_TRIL_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unsqueeze_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unsqueeze_parameter.h new file mode 100644 index 00000000..e7f6d643 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unsqueeze_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_UNSQUEEZE_PARAMETER_H_ +#define NNACL_UNSQUEEZE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct UnSqueezeQuantArg { + int *output_shape_; + float alpha; + int axis_; + size_t input_num_; + QuantArg in_quant_args_; + QuantArg out_quant_args_; +} UnSqueezeQuantArg; + +typedef struct UnSqueezeParameter { + // primitive parameter + OpParameter op_parameter_; + int dims_[COMM_SHAPE_SIZE]; + int num_dim_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + int64_t offset_[COMM_SHAPE_SIZE]; + int64_t axis_; + + // other parameter + UnSqueezeQuantArg quant_arg; + int thread_count_; +} UnSqueezeParameter; + +#endif // NNACL_UNSQUEEZE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unstack_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unstack_parameter.h new file mode 100644 index 00000000..0e328190 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/unstack_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef NNACL_UNSTACK_PARAMETER_H_ +#define NNACL_UNSTACK_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct UnstackParameter { + // primitive parameter + OpParameter op_parameter_; + int num_; + int axis_; + + // other parameter + int pre_dims_; + int axis_dim_; + int after_dims_; +} UnstackParameter; + +#endif // NNACL_UNSTACK_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/upsample_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/upsample_parameter.h new file mode 100644 index 00000000..79b931b6 --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/upsample_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_UPSAMPLE_PARAMETER_H_ +#define NNACL_UPSAMPLE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" +typedef struct { + // primitive parameter + OpParameter op_parameter_; + + // other parameter + int method_; // 0 for bilinear; 1 for nearest +} UpsampleParameter; + +#endif // NNACL_UPSAMPLE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl_c/where_parameter.h b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/where_parameter.h new file mode 100644 index 00000000..973569fa --- /dev/null +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl_c/where_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef NNACL_WHERE_PARAMETER_H_ +#define NNACL_WHERE_PARAMETER_H_ + +#include "nnacl_c/op_base.h" + +typedef struct WhereParameter { + OpParameter op_parameter_; +} WhereParameter; + +#endif // NNACL_WHERE_PARAMETER_H_ diff --git a/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h b/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h index 30d190d5..7f55b281 100644 --- a/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h +++ b/mindspore-lite/src/litert/kernel/cpu/string/lsh_projection.h @@ -19,7 +19,7 @@ #include -#include "nnacl/lsh_projection_parameter.h" +#include "nnacl_c/lsh_projection_parameter.h" #include "src/litert/lite_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/string/predict.h b/mindspore-lite/src/litert/kernel/cpu/string/predict.h index 768d1d4c..9f5f6fa4 100644 --- a/mindspore-lite/src/litert/kernel/cpu/string/predict.h +++ b/mindspore-lite/src/litert/kernel/cpu/string/predict.h @@ -18,7 +18,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/predict_parameter.h" +#include "nnacl_c/predict_parameter.h" namespace mindspore::kernel { class PredictCPUKernel : public LiteKernel { diff --git a/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h b/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h index f5e7a29b..9b32619b 100644 --- a/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h +++ b/mindspore-lite/src/litert/kernel/cpu/string/skip_gram.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" #include "src/common/string_utils.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc b/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc index 6a578081..fdc3a87c 100644 --- a/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc +++ b/mindspore-lite/src/litert/kernel/gpu/opencl/opencl_executor.cc @@ -16,7 +16,7 @@ #include "src/litert/kernel/gpu/opencl/opencl_executor.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/pack.h" +#include "nnacl_c/pack.h" #include "include/errorcode.h" namespace mindspore::lite::opencl { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h b/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h index 11bcbb78..dcec9c9b 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/activation.h @@ -21,7 +21,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" namespace mindspore::kernel { class ActivationOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h b/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h index 6277b480..0fdfe87c 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/argminmax.h @@ -19,8 +19,8 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/arg_min_max_parameter.h" -#include "nnacl/kernel/arg_min_max.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/kernel/arg_min_max.h" namespace mindspore::kernel { class ArgMinMaxOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc 
b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc index 603ae559..5ffc86b0 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic.cc @@ -21,7 +21,7 @@ #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/arithmetic.cl.inc" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h index d819fb29..68c28468 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/arithmetic_self.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/arithmetic_self_parameter.h" using mindspore::schema::PrimitiveType_Abs; using mindspore::schema::PrimitiveType_Ceil; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h b/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h index 1d5fa370..a8e6049f 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/batch_to_space_nd.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/batch_to_space_parameter.h" +#include "nnacl_c/batch_to_space_parameter.h" namespace mindspore::kernel { class BatchToSpaceNDOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc index c6019a43..6f240aeb 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.cc @@ -21,7 +21,7 @@ #include "src/litert/kernel/opencl/kernel/batchnorm.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/batchnorm.cl.inc" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h index 44563b8d..94f38404 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/batchnorm.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/batchnorm_fp32.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" namespace mindspore::kernel { class BatchNormOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h b/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h index 3876a745..15094387 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/concat.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" namespace mindspore { namespace kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h index f10a0026..4ad08678 100644 
--- a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.h @@ -22,7 +22,7 @@ #include "src/tensor.h" #include "src/litert/kernel/opencl/opencl_kernel.h" #include "schema/model_generated.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "schema/ops_generated.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc index d38d1eff..e8efae25 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.cc @@ -17,7 +17,7 @@ #include "src/litert/kernel/opencl/kernel/conv2d_transpose.h" #include #include -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/cl/conv2d_transpose.cl.inc" #include "src/litert/kernel/opencl/utils.h" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h index 3a0b6491..bb756aef 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d_transpose.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h b/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h index 0abaa292..117ee024 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/crop.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/crop_parameter.h" +#include "nnacl_c/crop_parameter.h" namespace mindspore::kernel { class CropOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc index 6e3a1888..6b923fb7 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.cc @@ -22,8 +22,8 @@ #include #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/common_func_fp32.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" #include "src/litert/kernel/opencl/cl/depthwise_conv2d.cl.inc" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h index 26172d7a..909f61d8 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/depthwise_conv2d.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" using mindspore::lite::opencl::MemType; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h b/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h index 7b788413..f81873a6 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h +++ 
b/mindspore-lite/src/litert/kernel/opencl/kernel/fill.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_ #include -#include "nnacl/base/fill_base.h" +#include "nnacl_c/base/fill_base.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc index 99efc6f8..9f61a966 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc @@ -17,7 +17,7 @@ #include #include #include -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/kernel/fullconnection.h" #include "src/litert/kernel/opencl/utils.h" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h index c895a9db..da851b05 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::kernel { class FullConnectionOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc index 16e067e1..386796c7 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.cc @@ -17,9 +17,9 @@ #include #include "src/litert/kernel/opencl/utils.h" #include "include/errorcode.h" -#include "nnacl/arithmetic_parameter.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/scale_parameter.h" #include "src/litert/infer_manager.h" using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h index 30e042af..54900401 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/fusion_eltwise.h @@ -28,7 +28,7 @@ #include "src/litert/kernel/opencl/kernel/arithmetic_self.h" #include "src/litert/kernel/opencl/kernel/to_format.h" #include "schema/ops_generated.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::schema::ActivationType; using mindspore::schema::PrimitiveType; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h b/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h index b42ccf03..98da3725 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/gather.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/gather_parameter.h" namespace mindspore::kernel { class GatherOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc index cca73b16..48cc00ef 100644 --- 
a/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/int8/arithmetic_int8.cc @@ -16,12 +16,12 @@ #include "src/litert/kernel/opencl/kernel/int8/arithmetic_int8.h" #include -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/int8/arithmetic.cl.inc" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc index 44982a34..c7f4cfd4 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/layer_norm.cc @@ -19,7 +19,7 @@ #include #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/kernel/layer_norm.h" -#include "nnacl/layer_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/layer_norm.cl.inc" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h b/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h index 07f971aa..3981a8c1 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/matmul.h @@ -20,7 +20,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" #include "src/common/utils.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::kernel { class MatMulOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h b/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h index d1bec9fb..dcec9932 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/one_hot.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl_c/fp32/one_hot_fp32.h" namespace mindspore::kernel { class OneHotOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h b/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h index 195ec7c1..253a9adf 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/pad.h @@ -22,7 +22,7 @@ #include "src/tensor.h" #include "src/litert/kernel/opencl/opencl_kernel.h" #include "schema/model_generated.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" namespace mindspore::kernel { class PadOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h b/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h index f65e93fa..5e6ef777 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/pooling2d.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/pooling_fp32.h" +#include "nnacl_c/fp32/pooling_fp32.h" namespace mindspore::kernel { class PoolingOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/power.h 
b/mindspore-lite/src/litert/kernel/opencl/kernel/power.h index d2c3eae2..33f02aac 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/power.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/power.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_POWER_H_ #include -#include "nnacl/fp32/power_fp32.h" +#include "nnacl_c/fp32/power_fp32.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc index 501bcf56..8c319140 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/prelu.cc @@ -17,13 +17,13 @@ */ #include "src/litert/kernel/opencl/kernel/prelu.h" -#include +#include "nnacl_c/prelu_parameter.h" #include #include #include "src/litert/kernel/opencl/cl/prelu.cl.inc" #include "src/litert/kernel_registry.h" #include "include/errorcode.h" -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h b/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h index 1b093caa..f66d28b0 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/reduce.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" namespace mindspore::kernel { class ReduceOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h b/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h index 74d1098a..8f7d9143 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/resize.h @@ -20,7 +20,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" namespace mindspore::kernel { class ResizeOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc index c63ef367..9fdfa3f7 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.cc @@ -20,7 +20,7 @@ #include #include "schema/model_generated.h" #include "src/litert/kernel_registry.h" -#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl_c/fp32/common_func_fp32.h" #include "src/litert/kernel/opencl/utils.h" #include "src/litert/kernel/opencl/cl/scale.cl.inc" diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h index bac659e0..09a70acf 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/scale.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SCALE_H_ #include -#include "nnacl/scale_parameter.h" +#include "nnacl_c/scale_parameter.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc index bafaa336..e15effb7 100644 --- 
a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.cc @@ -19,7 +19,7 @@ #include "include/errorcode.h" #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "src/litert/kernel/opencl/cl/softmax.cl.inc" using mindspore::kernel::KERNEL_ARCH::kGPU; diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h index 84f22396..22f598da 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/softmax.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/softmax_fp32.h" +#include "nnacl_c/fp32/softmax_fp32.h" namespace mindspore::kernel { class SoftmaxOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h index 6dcd0ae6..bde3a88b 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_batch_nd.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" namespace mindspore::kernel { class SpaceToBatchNDOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h index 75ecf703..e00c5177 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/space_to_depth.h @@ -21,7 +21,7 @@ #include #include "src/litert/lite_kernel.h" #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/space_to_depth_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" namespace mindspore::kernel { class SpaceToDepthOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h b/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h index 63181ed7..306411b7 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/sparse_to_dense.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" namespace mindspore::kernel { class SparseToDenseOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/split.h b/mindspore-lite/src/litert/kernel/opencl/kernel/split.h index 9f094f45..80ebb4e1 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/split.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/split.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::kernel { class SplitOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h b/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h index 4e7bef91..9e7f1be4 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/stack.h @@ -19,7 +19,7 @@ #include #include 
"src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" namespace mindspore::kernel { class StackOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h b/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h index dbc78f75..129de4fa 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/strided_slice.h @@ -19,7 +19,7 @@ #include #include "src/litert/kernel/opencl/opencl_kernel.h" -#include "nnacl/base/slice_base.h" +#include "nnacl_c/base/slice_base.h" namespace mindspore::kernel { class StridedSliceOpenCLKernel : public OpenCLKernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h b/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h index 64665735..f7cc7c06 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/transpose.h @@ -19,7 +19,7 @@ #include #include "src/litert/lite_kernel.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "src/litert/kernel/opencl/opencl_kernel.h" namespace mindspore::kernel { diff --git a/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc b/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc index d1d886dd..b947c2c6 100644 --- a/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc +++ b/mindspore-lite/src/litert/kernel/opencl/kernel/winograd.cc @@ -17,8 +17,8 @@ #include "src/litert/kernel/opencl/kernel/winograd.h" #include #include "src/litert/kernel/opencl/cl/winograd.cl.inc" -#include "nnacl/base/minimal_filtering_generator.h" -#include "nnacl/errorcode.h" +#include "nnacl_c/base/minimal_filtering_generator.h" +#include "nnacl_c/errorcode.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff --git a/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc b/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc index 72186219..b4e3ae20 100644 --- a/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc +++ b/mindspore-lite/src/litert/kernel/opencl/opencl_fusion.cc @@ -27,13 +27,13 @@ #include "include/errorcode.h" #include "schema/ops_generated.h" #include "src/common/utils.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/pad_parameter.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/matmul_parameter.h" -#include "nnacl/scale_parameter.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/pad_parameter.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" using mindspore::schema::ActivationType; using mindspore::schema::ActivationType_LEAKY_RELU; diff --git a/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h b/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h index c8c016bd..86d07453 100644 --- a/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h +++ b/mindspore-lite/src/litert/kernel/opencl/opencl_kernel.h @@ -29,7 +29,7 @@ #include "src/litert/kernel/gpu/opencl/opencl_runtime.h" #include "src/litert/tensor_category.h" #include "src/litert/kernel/opencl/utils.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; diff 
--git a/mindspore-lite/src/litert/kernel/opencl/utils.h b/mindspore-lite/src/litert/kernel/opencl/utils.h index 1b709609..d392c7a1 100644 --- a/mindspore-lite/src/litert/kernel/opencl/utils.h +++ b/mindspore-lite/src/litert/kernel/opencl/utils.h @@ -22,7 +22,7 @@ #include #include "CL/cl2.hpp" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/lite_kernel.h" #include "src/common/utils.h" #include "src/litert/kernel/opencl/opencl_kernel.h" diff --git a/mindspore-lite/src/litert/kernel_exec_util.cc b/mindspore-lite/src/litert/kernel_exec_util.cc index 82ffb7e2..b4a885ee 100644 --- a/mindspore-lite/src/litert/kernel_exec_util.cc +++ b/mindspore-lite/src/litert/kernel_exec_util.cc @@ -20,7 +20,7 @@ #include #include #include "src/executor/sub_graph_kernel.h" -#include "nnacl/call_parameter.h" +#include "nnacl_c/call_parameter.h" #if GPU_OPENCL #include "src/litert/kernel/opencl/opencl_subgraph.h" #include "src/litert/kernel/gpu/opencl/opencl_runtime.h" diff --git a/mindspore-lite/src/litert/kernel_registry.cc b/mindspore-lite/src/litert/kernel_registry.cc index f69086e6..d28eaf5b 100644 --- a/mindspore-lite/src/litert/kernel_registry.cc +++ b/mindspore-lite/src/litert/kernel_registry.cc @@ -22,7 +22,7 @@ #endif #include "src/common/ops/populate/populate_register.h" #include "src/common/version_manager.h" -#include "nnacl/pooling_parameter.h" +#include "nnacl_c/pooling_parameter.h" #if defined(ENABLE_FP16) && defined(ENABLE_ARM) #if defined(__ANDROID__) #include diff --git a/mindspore-lite/src/litert/lite_kernel.h b/mindspore-lite/src/litert/lite_kernel.h index 6b58b59d..501ab5d1 100644 --- a/mindspore-lite/src/litert/lite_kernel.h +++ b/mindspore-lite/src/litert/lite_kernel.h @@ -23,7 +23,7 @@ #include #include "src/common/utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/litert/inner_context.h" #include "src/tensor.h" #include "include/errorcode.h" diff --git a/mindspore-lite/src/litert/lite_model.h b/mindspore-lite/src/litert/lite_model.h index 2e62655c..24a4da0e 100644 --- a/mindspore-lite/src/litert/lite_model.h +++ b/mindspore-lite/src/litert/lite_model.h @@ -27,7 +27,7 @@ #include "src/common/log_adapter.h" #include "src/common/version_manager.h" #include "src/litert/schema_tensor_wrapper.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/prim_util.h" #ifdef ENABLE_MODEL_OBF #include "tools/obfuscator/include/deobfuscator.h" diff --git a/mindspore-lite/src/litert/mindrt_executor.cc b/mindspore-lite/src/litert/mindrt_executor.cc index 7162ea69..16c45476 100644 --- a/mindspore-lite/src/litert/mindrt_executor.cc +++ b/mindspore-lite/src/litert/mindrt_executor.cc @@ -23,9 +23,9 @@ #include "src/common/common.h" #include "src/common/tensor_util.h" #ifdef ENABLE_FP16 -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" #endif -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #include "src/litert/kernel_exec_util.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/litert/pass/format_pass/format_pass.cc b/mindspore-lite/src/litert/pass/format_pass/format_pass.cc index c18ae7f2..0be0ad79 100644 --- a/mindspore-lite/src/litert/pass/format_pass/format_pass.cc +++ b/mindspore-lite/src/litert/pass/format_pass/format_pass.cc @@ -19,7 +19,7 @@ #include "src/litert/pass/format_pass/eliminate_transpose.h" #ifdef ENABLE_MULTI_LAYOUT #include "src/litert/kernel_registry.h" -#include 
"nnacl/format_transpose_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" #endif #include "src/common/draw/drawer.h" diff --git a/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc b/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc index b5543f76..9b71774b 100644 --- a/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc +++ b/mindspore-lite/src/litert/pass/format_pass/insert_transpose.cc @@ -17,7 +17,7 @@ #include "src/litert/pass/format_pass/insert_transpose.h" #include "src/litert/pass/format_pass/format_utils.h" #include "src/litert/kernel_exec_util.h" -#include "nnacl/base/format_transpose.h" +#include "nnacl_c/base/format_transpose.h" namespace mindspore::lite::pass { int InsertTranspose::TransposeConstData(kernel::KernelExec *kernel, size_t index) { diff --git a/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc b/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc index 7ba83911..c9067433 100644 --- a/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc +++ b/mindspore-lite/src/litert/pass/format_pass/pass_utils.cc @@ -17,8 +17,8 @@ #include "src/litert/pass/format_pass/pass_utils.h" #include #include -#include "nnacl/format_transpose_parameter.h" -#include "nnacl/arg_min_max_parameter.h" +#include "nnacl_c/format_transpose_parameter.h" +#include "nnacl_c/arg_min_max_parameter.h" namespace mindspore::lite::pass { bool IsNoneTranspose(const TransInfoPair &trans) { diff --git a/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc b/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc index c68cc9d1..15ed7a3e 100644 --- a/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc +++ b/mindspore-lite/src/litert/pass/format_pass/transpose_strategy.cc @@ -15,18 +15,18 @@ */ #include "src/litert/pass/format_pass/transpose_strategy.h" -#include "nnacl/op_base.h" -#include "nnacl/arg_min_max_parameter.h" -#include "nnacl/concat_parameter.h" -#include "nnacl/crop_parameter.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/split_parameter.h" -#include "nnacl/squeeze_parameter.h" -#include "nnacl/stack_parameter.h" -#include "nnacl/unsqueeze_parameter.h" -#include "nnacl/unstack_parameter.h" -#include "nnacl/slice_parameter.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/arg_min_max_parameter.h" +#include "nnacl_c/concat_parameter.h" +#include "nnacl_c/crop_parameter.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/squeeze_parameter.h" +#include "nnacl_c/stack_parameter.h" +#include "nnacl_c/unsqueeze_parameter.h" +#include "nnacl_c/unstack_parameter.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" namespace mindspore::lite::pass { static const std::set arithmetic_kernel_lists = { diff --git a/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc index 2c29023f..773d1c23 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc @@ -18,7 +18,7 @@ #include #include "src/litert/pass/online_fusion/online_fusion_utils.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" #include "include/model.h" namespace { diff --git 
a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc index c5f802b2..7c458f39 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.cc @@ -17,9 +17,9 @@ #include "src/litert/pass/online_fusion/online_fusion_pass.h" #include #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" -#include "nnacl/reduce_parameter.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/reduce_parameter.h" +#include "nnacl_c/concat_parameter.h" #include "include/model.h" namespace mindspore::lite { diff --git a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h index 7ea6fe7f..d9574ee5 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h +++ b/mindspore-lite/src/litert/pass/online_fusion/online_fusion_pass.h @@ -29,8 +29,8 @@ #include "src/litert/sub_graph_split.h" #include "src/litert/pass/online_fusion/online_fusion_pass_registry.h" #include "src/common/prim_util.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite { class OnlineFusionPass { diff --git a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc index 43153432..2ff41458 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc @@ -18,8 +18,8 @@ #include #include "src/litert/pass/online_fusion/online_fusion_utils.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/reduce_parameter.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/reduce_parameter.h" +#include "nnacl_c/concat_parameter.h" #include "include/model.h" namespace { diff --git a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h index 52d8311e..be313205 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h +++ b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.h @@ -29,7 +29,7 @@ #include "src/litert/sub_graph_split.h" #include "src/litert/pass/online_fusion/online_fusion_pass.h" #include "src/common/prim_util.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite { class ReduceConcatOnlineFusionPass : public OnlineFusionPass { diff --git a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc index 18a835e6..76c44638 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc @@ -17,9 +17,9 @@ #include "src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h" #include #include "src/common/ops/populate/populate_register.h" -#include "nnacl/split_parameter.h" -#include "nnacl/concat_parameter.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/concat_parameter.h" +#include "nnacl_c/reduce_parameter.h" 
#include "include/model.h" namespace { diff --git a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h index aca4c830..4b65e5ba 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h +++ b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.h @@ -29,7 +29,7 @@ #include "src/litert/sub_graph_split.h" #include "src/litert/pass/online_fusion/online_fusion_pass.h" #include "src/common/prim_util.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite { class SplitReduceConcatOnlineFusionPass : public OnlineFusionPass { diff --git a/mindspore-lite/src/litert/runtime_packed_node_pass.cc b/mindspore-lite/src/litert/runtime_packed_node_pass.cc index ed7f54b9..85fdb395 100644 --- a/mindspore-lite/src/litert/runtime_packed_node_pass.cc +++ b/mindspore-lite/src/litert/runtime_packed_node_pass.cc @@ -14,10 +14,10 @@ * limitations under the License. */ #include "src/litert/runtime_packed_node_pass.h" -#include "nnacl/op_base.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/matmul_parameter.h" #include "nnacl/nnacl_kernel.h" -#include "nnacl/kernel/matmul_struct.h" +#include "nnacl_c/kernel/matmul_struct.h" #include "common/string_utils.h" using RecoveryWeightFunc = void (*)(void *, void *, int, int, bool); diff --git a/mindspore-lite/src/litert/runtime_pass.cc b/mindspore-lite/src/litert/runtime_pass.cc index da92cca5..a98cbbaf 100644 --- a/mindspore-lite/src/litert/runtime_pass.cc +++ b/mindspore-lite/src/litert/runtime_pass.cc @@ -15,7 +15,7 @@ */ #include "src/litert/runtime_pass.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite { #ifndef RUNTIME_PASS_CLIP diff --git a/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc b/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc index 79b10127..726138b4 100644 --- a/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc +++ b/mindspore-lite/src/litert/runtime_shape_fusion_pass.cc @@ -21,7 +21,7 @@ #include #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { namespace { diff --git a/mindspore-lite/src/litert/scheduler.cc b/mindspore-lite/src/litert/scheduler.cc index 440ec0d2..2e9a6844 100644 --- a/mindspore-lite/src/litert/scheduler.cc +++ b/mindspore-lite/src/litert/scheduler.cc @@ -22,7 +22,7 @@ #include #include #include "src/tensorlist.h" -#include "nnacl/partial_fusion_parameter.h" +#include "nnacl_c/partial_fusion_parameter.h" #include "include/errorcode.h" #include "src/common/graph_util.h" #include "src/common/utils.h" @@ -47,7 +47,7 @@ #endif #include "src/litert/weight_decoder.h" #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h" -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #if GPU_OPENCL #include "src/litert/kernel/opencl/opencl_subgraph.h" #include "src/litert/kernel/gpu/opencl/opencl_runtime.h" diff --git a/mindspore-lite/src/litert/schema_tensor_wrapper.cc b/mindspore-lite/src/litert/schema_tensor_wrapper.cc index 5ba0a55a..bc608f18 100644 --- a/mindspore-lite/src/litert/schema_tensor_wrapper.cc +++ b/mindspore-lite/src/litert/schema_tensor_wrapper.cc @@ -17,7 +17,7 @@ #include "src/litert/schema_tensor_wrapper.h" #include "src/common/log_adapter.h" #include "src/common/file_utils.h" -#include 
"nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/litert/sub_graph_split.cc b/mindspore-lite/src/litert/sub_graph_split.cc index b1655c8c..ef5f0798 100644 --- a/mindspore-lite/src/litert/sub_graph_split.cc +++ b/mindspore-lite/src/litert/sub_graph_split.cc @@ -27,9 +27,9 @@ #include "src/common/ops/populate/populate_register.h" #include "src/litert/scheduler.h" #include "src/litert/tensor_category.h" -#include "nnacl/pooling_parameter.h" +#include "nnacl_c/pooling_parameter.h" #include "include/model.h" -#include "nnacl/base/conv_common_base.h" +#include "nnacl_c/base/conv_common_base.h" namespace { constexpr const int kMaxDepth = 2048; diff --git a/mindspore-lite/src/litert/sub_graph_split.h b/mindspore-lite/src/litert/sub_graph_split.h index 2b16aedb..0bb7f08e 100644 --- a/mindspore-lite/src/litert/sub_graph_split.h +++ b/mindspore-lite/src/litert/sub_graph_split.h @@ -27,7 +27,7 @@ #include "src/litert/lite_model.h" #include "src/litert/inner_context.h" #include "src/common/prim_util.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite { constexpr int kDefaultSubGraphSize = 2; diff --git a/mindspore-lite/src/litert/thread_cost_model.cc b/mindspore-lite/src/litert/thread_cost_model.cc index 5bbd3d36..d3e6f5c0 100644 --- a/mindspore-lite/src/litert/thread_cost_model.cc +++ b/mindspore-lite/src/litert/thread_cost_model.cc @@ -19,7 +19,7 @@ #include "src/common/log_util.h" #include "src/litert/inner_context.h" #include "thread/threadpool.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { const std::map kernel_compute_cost_map_ = { diff --git a/mindspore-lite/src/litert/thread_cost_model.h b/mindspore-lite/src/litert/thread_cost_model.h index 70c9ca9d..0254b1ef 100644 --- a/mindspore-lite/src/litert/thread_cost_model.h +++ b/mindspore-lite/src/litert/thread_cost_model.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_THREAD_COST_MODEL_H_ #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/api/context.h" #include "schema/ops_generated.h" diff --git a/mindspore-lite/src/litert/weight_decoder.cc b/mindspore-lite/src/litert/weight_decoder.cc index aa6ca73b..5e21960f 100644 --- a/mindspore-lite/src/litert/weight_decoder.cc +++ b/mindspore-lite/src/litert/weight_decoder.cc @@ -18,7 +18,7 @@ #include "src/litert/weight_decoder.h" #include "src/litert/huffman_decode.h" #include "tools/converter/quantizer/fse_decoder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite { #ifndef WEIGHT_DECODE_CLIP diff --git a/mindspore-lite/src/litert/weight_decoder.h b/mindspore-lite/src/litert/weight_decoder.h index 4df0eb82..371b3eea 100644 --- a/mindspore-lite/src/litert/weight_decoder.h +++ b/mindspore-lite/src/litert/weight_decoder.h @@ -24,8 +24,8 @@ #include #include #include -#include "nnacl/matmul_parameter.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/matmul_parameter.h" +#include "nnacl_c/gather_parameter.h" #include "src/executor/kernel_exec.h" #include "src/common/utils.h" #include "src/tensor.h" diff --git a/mindspore-lite/src/tensor.h b/mindspore-lite/src/tensor.h index a02a51b9..724674db 100644 --- a/mindspore-lite/src/tensor.h +++ b/mindspore-lite/src/tensor.h @@ -26,8 +26,8 @@ #include #include "include/api/format.h" #include "include/lite_types.h" -#include "nnacl/tensor_c.h" -#include "nnacl/tensor_c_utils.h" +#include "nnacl_c/tensor_c.h" 
+#include "nnacl_c/tensor_c_utils.h" #include "src/litert/inner_allocator.h" #include "src/common/log_adapter.h" #include "src/common/utils.h" diff --git a/mindspore-lite/src/tensorlist.cc b/mindspore-lite/src/tensorlist.cc index 13fe91e3..10d87438 100644 --- a/mindspore-lite/src/tensorlist.cc +++ b/mindspore-lite/src/tensorlist.cc @@ -19,7 +19,7 @@ #include #include "src/common/log_adapter.h" #include "src/tensor.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { #ifndef CONTROLFLOW_TENSORLIST_CLIP diff --git a/mindspore-lite/src/tensorlist.h b/mindspore-lite/src/tensorlist.h index 5c85e29d..406dd5fa 100644 --- a/mindspore-lite/src/tensorlist.h +++ b/mindspore-lite/src/tensorlist.h @@ -20,7 +20,7 @@ #include #include #include "include/errorcode.h" -#include "nnacl/tensorlist_c.h" +#include "nnacl_c/tensorlist_c.h" #include "src/common/log_adapter.h" #include "schema/model_generated.h" #include "src/tensor.h" diff --git a/mindspore-lite/src/train/opt_allocator.cc b/mindspore-lite/src/train/opt_allocator.cc index d9931641..485e4b97 100644 --- a/mindspore-lite/src/train/opt_allocator.cc +++ b/mindspore-lite/src/train/opt_allocator.cc @@ -15,7 +15,7 @@ */ #include "src/train/opt_allocator.h" #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { size_t OptAllocator::FindFree(size_t size) { diff --git a/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc b/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc index 435686e5..91655293 100644 --- a/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc +++ b/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc @@ -24,7 +24,7 @@ #include #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/train/train_loop.cc b/mindspore-lite/src/train/train_loop.cc index b565443b..125c4ede 100644 --- a/mindspore-lite/src/train/train_loop.cc +++ b/mindspore-lite/src/train/train_loop.cc @@ -22,7 +22,7 @@ #include "include/errorcode.h" #include "include/dataset/iterator.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/src/train/train_populate_parameter.cc b/mindspore-lite/src/train/train_populate_parameter.cc index c1a70b92..3cfdd24d 100644 --- a/mindspore-lite/src/train/train_populate_parameter.cc +++ b/mindspore-lite/src/train/train_populate_parameter.cc @@ -17,21 +17,21 @@ #include "include/securec.h" #include "src/common/ops/populate/populate_register.h" #include "src/common/ops/populate/default_populate.h" -#include "nnacl/strided_slice_parameter.h" -#include "nnacl/arithmetic_parameter.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/pow_parameter.h" -#include "nnacl/activation_parameter.h" -#include "nnacl/fp32_grad/softmax_crossentropy_parameter.h" -#include "nnacl/fp32_grad/optimizer.h" -#include "nnacl/fp32_grad/batch_norm_parameter.h" -#include "nnacl/fp32_grad/dropout_parameter.h" -#include "nnacl/fp32_grad/smooth_l1_loss.h" -#include "nnacl/fp32_grad/resize_grad_parameter.h" -#include "nnacl/fp32_grad/lstm_grad_fp32.h" -#include "nnacl/fp32_grad/binary_cross_entropy.h" -#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" +#include "nnacl_c/strided_slice_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/conv_parameter.h" 
+#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/pow_parameter.h" +#include "nnacl_c/activation_parameter.h" +#include "nnacl_c/fp32_grad/softmax_crossentropy_parameter.h" +#include "nnacl_c/fp32_grad/optimizer.h" +#include "nnacl_c/fp32_grad/batch_norm_parameter.h" +#include "nnacl_c/fp32_grad/dropout_parameter.h" +#include "nnacl_c/fp32_grad/smooth_l1_loss.h" +#include "nnacl_c/fp32_grad/resize_grad_parameter.h" +#include "nnacl_c/fp32_grad/lstm_grad_fp32.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy.h" +#include "nnacl_c/fp32_grad/binary_cross_entropy_grad.h" using mindspore::lite::Registry; diff --git a/mindspore-lite/test/CMakeLists.txt b/mindspore-lite/test/CMakeLists.txt index 0d076086..f44d75b9 100644 --- a/mindspore-lite/test/CMakeLists.txt +++ b/mindspore-lite/test/CMakeLists.txt @@ -4,7 +4,7 @@ set(LITE_DIR ${TOP_DIR}/mindspore-lite) include_directories(${TOP_DIR}) include_directories(${TEST_DIR}) -include_directories(${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu) +include_directories(${NNACL_DIR}/../) include(${TOP_DIR}/cmake/external_libs/gtest.cmake) include(${TOP_DIR}/cmake/external_libs/mockcpp.cmake) diff --git a/mindspore-lite/test/common/common_test.h b/mindspore-lite/test/common/common_test.h index aef9e7d9..e8f32d3b 100644 --- a/mindspore-lite/test/common/common_test.h +++ b/mindspore-lite/test/common/common_test.h @@ -25,8 +25,8 @@ #include "gtest/gtest.h" #include "include/api/format.h" #include "src/litert/tensor_category.h" -#include "nnacl/tensorlist_c_utils.h" -#include "nnacl/tensor_c_utils.h" +#include "nnacl_c/tensorlist_c_utils.h" +#include "nnacl_c/tensor_c_utils.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc index 31834b01..355bd16e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/adam_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/adam_infer.h" +#include "nnacl_c/infer/adam_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc index 52dae7fb..62f8654c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/adam_weight_decay_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" -#include "nnacl/infer/adam_weight_decay_infer.h" +#include "nnacl_c/infer/adam_weight_decay_infer.h" namespace mindspore { class AdamWeightDecayInfer : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc index 0604cfa8..6d5cd2c8 100644 --- a/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/addn_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/addn_infer.h" +#include "nnacl_c/infer/addn_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc index 5c21e823..ee3b919c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/apply_momentum_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/apply_momentum_infer.h" +#include "nnacl_c/infer/apply_momentum_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc index d003c9d2..10ecad2c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/argmax_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/argmin_max_infer.h" +#include "nnacl_c/infer/argmin_max_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc index 7fe54cfd..1c6ed96a 100644 --- a/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/argmin_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/argmin_max_infer.h" +#include "nnacl_c/infer/argmin_max_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc index 869edf1d..c1096310 100644 --- a/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/arithmetic_compare_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/arithmetic_compare_infer.h" +#include "nnacl_c/infer/arithmetic_compare_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc index 146c3f0f..fa759509 100644 --- a/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/arithmetic_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/arithmetic_infer.h" +#include "nnacl_c/infer/arithmetic_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc index 6f6e823f..b2055260 100644 --- a/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/assign_add_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/assign_add_infer.h" +#include "nnacl_c/infer/assign_add_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc index 8263392c..739c25be 100644 --- a/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/assign_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/assign_infer.h" +#include "nnacl_c/infer/assign_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc index 5c724ae1..14497e73 100644 --- a/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/audio_spectrogram_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/audio_spectrogram_infer.h" +#include "nnacl_c/infer/audio_spectrogram_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc index 0541102b..2985f2b7 100644 --- a/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/batch_to_space_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/batch_to_space_infer.h" +#include "nnacl_c/infer/batch_to_space_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc index 55c5d2de..855e0a4e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/bias_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/bias_grad_infer.h" +#include "nnacl_c/infer/bias_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc index b6a7e3f8..25c51b06 100644 --- a/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/binary_cross_entropy_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/binary_cross_entropy_infer.h" +#include "nnacl_c/infer/binary_cross_entropy_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc index 8fbabbca..9ea08ec6 100644 --- a/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/bn_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/bn_grad_infer.h" +#include "nnacl_c/infer/bn_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc index 31e2b55c..34f44a04 100644 --- a/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/broadcast_to_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/broadcast_to_infer.h" +#include "nnacl_c/infer/broadcast_to_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc index 3c490677..18b54fce 100644 --- a/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/cast_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/cast_infer.h" +#include "nnacl_c/infer/cast_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc index 03624cfd..a3b15dd2 100644 --- a/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/concat_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/concat_infer.h" +#include "nnacl_c/infer/concat_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc index 792ea55b..7f4cbd53 100644 --- a/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/constant_of_shape_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/constant_of_shape_infer.h" +#include "nnacl_c/infer/constant_of_shape_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc index 868768e8..1d22281e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_filter_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/conv2d_grad_filter_infer.h" +#include "nnacl_c/infer/conv2d_grad_filter_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc index 12e0fbb1..cce2634d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/conv2d_grad_input_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/conv2d_grad_input_infer.h" +#include "nnacl_c/infer/conv2d_grad_input_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc index 749cddb2..3f88a643 100644 --- a/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/conv2d_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/conv2d_infer.h" +#include "nnacl_c/infer/conv2d_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc index 01cfd8a1..5bee3402 100644 --- a/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/crop_and_resize_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/crop_and_resize_infer.h" +#include "nnacl_c/infer/crop_and_resize_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc index ca1798d1..434d97a1 100644 --- a/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/crop_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/crop_infer.h" +#include "nnacl_c/infer/crop_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc index e66522f0..5a09a796 100644 --- a/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/cumsum_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/cumsum_infer.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/infer/cumsum_infer.h" +#include "nnacl_c/cumsum_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc index cd1ded3f..c38148aa 100644 --- a/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/custom_extract_features_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/custom_extract_features_infer.h" +#include "nnacl_c/infer/string/custom_extract_features_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc index 1c84fdd7..081aa543 100644 --- a/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/custom_normalize_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/custom_normalize_infer.h" +#include "nnacl_c/infer/string/custom_normalize_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc index b908aa7a..415c5c88 100644 --- a/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/custom_predict_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/custom_predict_infer.h" +#include "nnacl_c/infer/string/custom_predict_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc index aa2360f0..3e1f9579 100644 --- a/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/deconv2d_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/deconv2d_infer.h" +#include "nnacl_c/infer/deconv2d_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc index 7b947e4b..2508e212 100644 --- a/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/depth_to_space_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/depth_to_space_infer.h" +#include "nnacl_c/infer/depth_to_space_infer.h" #include "src/tensor.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc index 74e9d27e..cf96e455 100644 --- a/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/depthwise_conv2d_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/depthwise_conv2d_infer.h" +#include "nnacl_c/infer/depthwise_conv2d_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc index 9dcca21d..0321ff0c 100644 --- a/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/detection_post_process_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/detection_post_process_infer.h" +#include "nnacl_c/infer/detection_post_process_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc index 758205e4..9731d71e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/dropout_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/dropout_grad_infer.h" +#include "nnacl_c/infer/dropout_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc index 98b8cf24..995b129f 100644 --- a/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/embedding_lookup_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/embedding_lookup_infer.h" +#include "nnacl_c/infer/embedding_lookup_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc index 92c099f8..039ae90a 100644 --- a/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/expand_dims_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/expand_dims_infer.h" +#include "nnacl_c/infer/expand_dims_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc index 30730831..36390924 100644 --- a/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/fft_imag_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/fft_imag_infer.h" +#include "nnacl_c/infer/fft_imag_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc index 2ad4b009..451462af 100644 --- a/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/fill_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/fill_infer.h" -#include "nnacl/fill_parameter.h" +#include "nnacl_c/infer/fill_infer.h" +#include "nnacl_c/fill_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc index 686fdf22..acad0586 100644 --- a/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/flatten_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/flatten_grad_infer.h" +#include "nnacl_c/infer/flatten_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc index f8152793..24029984 100644 --- a/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/flatten_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/flatten_infer.h" +#include "nnacl_c/infer/flatten_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc index 0dd33a6a..33f14746 100644 --- a/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/full_connection_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/full_connection_infer.h" +#include "nnacl_c/infer/full_connection_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc index 5df7e3e2..b5bdd027 100644 --- a/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/fused_batchnorm_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/fused_batchnorm_infer.h" +#include "nnacl_c/infer/fused_batchnorm_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc index 079c9c3c..5d35f7a8 100644 --- a/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/gather_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/gather_infer.h" +#include "nnacl_c/infer/gather_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc index 26fd731c..5c203271 100644 --- a/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/gather_nd_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/gather_nd_infer.h" -#include "nnacl/gather_nd_parameter.h" +#include "nnacl_c/infer/gather_nd_infer.h" +#include "nnacl_c/gather_nd_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc index f687aa7d..d2b9b3a7 100644 --- a/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/group_conv2d_grad_input_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/group_conv2d_grad_input_infer.h" +#include "nnacl_c/infer/group_conv2d_grad_input_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc index 6f917cf1..dde9f8a7 100644 --- a/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/gru_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/gru_infer.h" +#include "nnacl_c/infer/gru_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc index 4768bedf..37eb8114 100644 --- a/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/hashtable_lookup_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/string/hashtable_lookup_infer.h" +#include "nnacl_c/infer/string/hashtable_lookup_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc index 78b5666f..07467648 100644 --- a/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/invert_permutation_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/invert_permutation_infer.h" +#include "nnacl_c/infer/invert_permutation_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc index 167f8a03..68498187 100644 --- a/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/layer_norm_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/layer_norm_infer.h" +#include "nnacl_c/infer/layer_norm_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc index 33717760..72a19e04 100644 --- a/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/lsh_projection_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/string/lsh_projection_infer.h" +#include "nnacl_c/infer/string/lsh_projection_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc index ef24be40..5c3434eb 100644 --- a/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/lstm_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/lstm_infer.h" +#include "nnacl_c/infer/lstm_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc index 1c2c807d..e5e966dc 100644 --- a/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/matmul_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/matmul_infer.h" +#include "nnacl_c/infer/matmul_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc index 107bcbf3..740a7278 100644 --- a/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/max_min_grad_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/max_min_grad_infer.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/infer/max_min_grad_infer.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc index f6c2a615..7efce63b 100644 --- a/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/mfcc_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/mfcc_infer.h" +#include "nnacl_c/infer/mfcc_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc index 260dcfa7..0c0480b8 100644 --- a/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/nllloss_grad_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" -#include "nnacl/infer/nllloss_grad_infer.h" +#include "nnacl_c/infer/nllloss_grad_infer.h" namespace mindspore { class TestNLLLossGradInfer : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc index 7de3cc4f..416d2299 100644 --- a/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/nllloss_infer_test.cc @@ -15,7 +15,7 @@ */ #include "common/common_test.h" -#include "nnacl/infer/nllloss_infer.h" +#include "nnacl_c/infer/nllloss_infer.h" namespace mindspore { class TestNLLLossInfer : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc index 03243518..0752eafd 100644 --- a/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/one_hot_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/one_hot_infer.h" +#include "nnacl_c/infer/one_hot_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc index 2c6473e1..b66b40a5 100644 --- a/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/pad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/pad_infer.h" +#include "nnacl_c/infer/pad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc index 74c32a2f..bb776ea3 100644 --- a/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/pooling_grad_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/pooling_grad_infer.h" +#include "nnacl_c/infer/pooling_grad_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc index 0fb27892..9f4d72d5 100644 --- a/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/pooling_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/pooling_infer.h" +#include "nnacl_c/infer/pooling_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc index dc7b10f7..0047b71d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/power_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/power_infer.h" +#include "nnacl_c/infer/power_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc index 9d60d2f4..7c99fb15 100644 --- a/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/quant_dtype_cast_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/quant_dtype_cast_infer.h" +#include "nnacl_c/infer/quant_dtype_cast_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc index ab81dcde..c1c6e56e 100644 --- a/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/random_standard_normal_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/random_standard_normal_infer.h" +#include "nnacl_c/infer/random_standard_normal_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc index 067490b7..8f40763d 100644 --- a/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/range_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include "common/common_test.h" -#include "nnacl/infer/range_infer.h" -#include "nnacl/range_parameter.h" +#include "nnacl_c/infer/range_infer.h" +#include "nnacl_c/range_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc index 8b6000c6..c2c99e47 100644 --- a/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/rank_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/rank_infer.h" +#include "nnacl_c/infer/rank_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc index a1823f74..a5bf6ed7 100644 --- a/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/reduce_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/reduce_infer.h" +#include "nnacl_c/infer/reduce_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc index ade7305d..47b4be80 100644 --- a/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/reshape_infer_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/reshape_infer.h" -#include "nnacl/reshape_parameter.h" +#include "nnacl_c/infer/reshape_infer.h" +#include "nnacl_c/reshape_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc index 5ac041cd..4aaf0a75 100644 --- a/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/resize_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/resize_infer.h" +#include "nnacl_c/infer/resize_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc index 20a33e82..ce630250 100644 --- a/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/rfft_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/infer/rfft_infer.h" +#include "nnacl_c/infer/rfft_infer.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc index 25293239..ee8cc2a7 100644 --- a/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc +++ b/mindspore-lite/test/ut/nnacl/infer/roi_pooling_infer_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
  */
 #include "common/common_test.h"
-#include "nnacl/infer/roi_pooling_infer.h"
+#include "nnacl_c/infer/roi_pooling_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc
index 25ca4491..5d3813c8 100644
--- a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_add_infer_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/common_test.h"
-#include "nnacl/scatter_nd_parameter.h"
-#include "nnacl/infer/scatter_nd_update_infer.h"
+#include "nnacl_c/scatter_nd_parameter.h"
+#include "nnacl_c/infer/scatter_nd_update_infer.h"
 namespace mindspore {
 class TestScatterNdAddInfer : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc
index 70efa6d3..3600d37d 100644
--- a/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/scatter_nd_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/scatter_nd_infer.h"
+#include "nnacl_c/infer/scatter_nd_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc
index f6d308f9..67f74a02 100644
--- a/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/select_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/select_infer.h"
+#include "nnacl_c/infer/select_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc
index 72c60a45..340bc8bf 100644
--- a/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/sgd_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/sgd_infer.h"
+#include "nnacl_c/infer/sgd_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc
index bbc10b3a..920f589d 100644
--- a/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/shape_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/shape_infer.h"
+#include "nnacl_c/infer/shape_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc
index f6e9b7d4..2325f9f8 100644
--- a/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/size_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/size_infer.h"
+#include "nnacl_c/infer/size_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc
index ef7adebb..9669e060 100644
--- a/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/skip_gram_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/string/skip_gram_infer.h"
+#include "nnacl_c/infer/string/skip_gram_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc
index 33fdcd74..087afc9c 100644
--- a/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/slice_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/slice_infer.h"
+#include "nnacl_c/infer/slice_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc
index 27a7c96c..b1cb2170 100644
--- a/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/softmax_cross_entropy_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/softmax_cross_entropy_infer.h"
+#include "nnacl_c/infer/softmax_cross_entropy_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc
index 83584587..0b440079 100644
--- a/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/softmax_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/softmax_infer.h"
+#include "nnacl_c/infer/softmax_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc
index c780c354..4786e258 100644
--- a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/space_to_batch_infer.h"
+#include "nnacl_c/infer/space_to_batch_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc
index 224d3edf..0a160e43 100644
--- a/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/space_to_batch_nd_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/space_to_batch_nd_infer.h"
+#include "nnacl_c/infer/space_to_batch_nd_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc
index 4293d921..ea69414f 100644
--- a/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/space_to_depth_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/space_to_depth_infer.h"
+#include "nnacl_c/infer/space_to_depth_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc
index a04d0d03..76bf4f90 100644
--- a/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/sparse_to_dense_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/sparse_to_dense_infer.h"
+#include "nnacl_c/infer/sparse_to_dense_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc
index 64242b0a..f755188d 100644
--- a/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/split_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/split_infer.h"
+#include "nnacl_c/infer/split_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc
index 3873efc6..5b181563 100644
--- a/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/squeeze_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/squeeze_infer.h"
+#include "nnacl_c/infer/squeeze_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc
index 7cddaab6..65435538 100644
--- a/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/stack_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/stack_infer.h"
+#include "nnacl_c/infer/stack_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc
index 37fbd4e5..6a4ab89c 100644
--- a/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/strided_slice_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/strided_slice_infer.h"
+#include "nnacl_c/infer/strided_slice_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc
index 481690f8..eafde941 100644
--- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_fromtensor_infer_test.cc
@@ -15,7 +15,7 @@
  */
 #include "common/common_test.h"
 #include "src/common/tensor_util.h"
-#include "nnacl/infer/control/tensorlist_fromtensor_infer.h"
+#include "nnacl_c/infer/control/tensorlist_fromtensor_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc
index 4f39c093..2a235823 100644
--- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_getitem_infer_test.cc
@@ -15,7 +15,7 @@
  */
 #include "common/common_test.h"
 #include "src/common/tensor_util.h"
-#include "nnacl/infer/control/tensorlist_getitem_infer.h"
+#include "nnacl_c/infer/control/tensorlist_getitem_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc
index abefbeea..fe826238 100644
--- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_reserve_infer_test.cc
@@ -15,7 +15,7 @@
  */
 #include "common/common_test.h"
 #include "src/common/tensor_util.h"
-#include "nnacl/infer/control/tensorlist_reserve_infer.h"
+#include "nnacl_c/infer/control/tensorlist_reserve_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc
index f726e159..81026926 100644
--- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_setitem_infer_test.cc
@@ -15,7 +15,7 @@
  */
 #include "common/common_test.h"
 #include "src/common/tensor_util.h"
-#include "nnacl/infer/control/tensorlist_setitem_infer.h"
+#include "nnacl_c/infer/control/tensorlist_setitem_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc
index 42a7adbf..8095670e 100644
--- a/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/tensorlist_stack_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/control/tensorlist_stack_infer.h"
+#include "nnacl_c/infer/control/tensorlist_stack_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc
index 705c127e..316c9f3f 100644
--- a/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/tile_infer_test.cc
@@ -14,9 +14,9 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/tile_infer.h"
-#include "nnacl/base/tile_base.h"
-#include "nnacl/tile_parameter.h"
+#include "nnacl_c/infer/tile_infer.h"
+#include "nnacl_c/base/tile_base.h"
+#include "nnacl_c/tile_parameter.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc
index d467ae27..55e98524 100644
--- a/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/topk_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/topk_infer.h"
+#include "nnacl_c/infer/topk_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc
index 5d300eb8..4096d1cd 100644
--- a/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/transpose_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/transpose_infer.h"
+#include "nnacl_c/infer/transpose_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc
index 0facfd6c..534921c9 100644
--- a/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/unique_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/unique_infer.h"
+#include "nnacl_c/infer/unique_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc
index 9094cc25..4f14be0f 100644
--- a/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/unsorted_segment_sum_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/unsorted_segment_sum_infer.h"
+#include "nnacl_c/infer/unsorted_segment_sum_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc
index 98ad4b3f..bcf3bd19 100644
--- a/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/unsqueeze_infer_test.cc
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/unsqueeze_infer.h"
-#include "nnacl/unsqueeze_parameter.h"
+#include "nnacl_c/infer/unsqueeze_infer.h"
+#include "nnacl_c/unsqueeze_parameter.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc
index ee31cae6..c5bcb70c 100644
--- a/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/unstack_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/unstack_infer.h"
+#include "nnacl_c/infer/unstack_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc b/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc
index 39bfe729..f24572b8 100644
--- a/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc
+++ b/mindspore-lite/test/ut/nnacl/infer/where_infer_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/infer/where_infer.h"
+#include "nnacl_c/infer/where_infer.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc b/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc
index d8df2ff1..18c2074e 100644
--- a/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc
+++ b/mindspore-lite/test/ut/nnacl/int8/quant_dtype_cast_int8_test.cc
@@ -16,9 +16,9 @@
 #include
 #include
 #include "common/common_test.h"
-#include "nnacl/int8/quant_dtype_cast_int8.h"
+#include "nnacl_c/int8/quant_dtype_cast_int8.h"
 #ifdef ENABLE_ARM64
-#include "nnacl/fp16/quant_dtype_cast_fp16.h"
+#include "nnacl_c/fp16/quant_dtype_cast_fp16.h"
 #endif
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc b/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc
index 63f95d22..ef8db4b3 100644
--- a/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc
+++ b/mindspore-lite/test/ut/nnacl/kernel/cast_test.cc
@@ -16,9 +16,9 @@
 #include
 #include
 #include "common/common_test.h"
-#include "nnacl/op_base.h"
-#include "nnacl/base/cast_base.h"
-#include "nnacl/kernel/cast.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/base/cast_base.h"
+#include "nnacl_c/kernel/cast.h"
 namespace mindspore {
 class CastTest : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc
index 5642f835..9c2fd933 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/pack_tests.cc
@@ -19,11 +19,11 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32/conv_common_fp32.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
 #ifdef ENABLE_FP16
-#include "nnacl/fp16/pack_fp16.h"
-#include "nnacl/fp16/conv_fp16.h"
+#include "nnacl_c/fp16/pack_fp16.h"
+#include "nnacl_c/fp16/conv_fp16.h"
 #endif
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc
index 971e0c4d..c51aa18c 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc
@@ -16,7 +16,7 @@
 #include
 #include "common/common_test.h"
-#include "nnacl/strided_slice_parameter.h"
+#include "nnacl_c/strided_slice_parameter.h"
 #include "nnacl/nnacl_manager.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc
index 2b671b7c..75f20b16 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/activation_grad_fp16_test.cc
@@ -21,7 +21,7 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "nnacl/fp16_grad/activation_grad_fp16.h"
+#include "nnacl_c/fp16_grad/activation_grad_fp16.h"
 namespace mindspore {
 class TestActGradFp16 : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc
index add5c8ec..3d5a767d 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp16_grad/arithmetic_fp16_self_grad_tests.cc
@@ -21,7 +21,7 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "nnacl/fp16_grad/arithmetic_self_grad.h"
+#include "nnacl_c/fp16_grad/arithmetic_self_grad.h"
 namespace mindspore {
 class TestArithmeticSelfGradFp16 : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc
index 6344e560..a19d152c 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32-sparsity/matmul_fp32_tests.cc
@@ -19,7 +19,7 @@
 #include
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "src/tensor.h"
 #include "include/securec.h"
 #include "src/litert/infer_manager.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc
index 9f7dc50f..2cab2292 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc
@@ -16,7 +16,7 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 #include "src/executor/kernel_exec.h"
 #include "nnacl/nnacl_manager.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc
index edb2dde9..a89a7338 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batch_to_space_fp32_test.cc
@@ -15,9 +15,9 @@
  */
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/base/batch_to_space_base.h"
-#include "nnacl/batch_to_space_parameter.h"
-#include "nnacl/common_func.h"
+#include "nnacl_c/base/batch_to_space_base.h"
+#include "nnacl_c/batch_to_space_parameter.h"
+#include "nnacl_c/common_func.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc
index 03995b9c..2943341b 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc
@@ -16,7 +16,7 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
"common/common_test.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" #include "nnacl/nnacl_manager.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc index 019d5ffa..6c6420df 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -15,7 +15,7 @@ */ #include #include "common/common_test.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "src/litert/kernel/cpu/fp32/convolution_1x1_fp32.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc index 7f28e655..83d510c4 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/crop_fp32_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "nnacl/fp32/crop_fp32.h" +#include "nnacl_c/fp32/crop_fp32.h" #include "src/litert/tensor_category.h" #include "src/litert/lite_kernel.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc index 2e483c0a..0fa93471 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/cumsum_tests.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" -#include "nnacl/cumsum_parameter.h" +#include "nnacl_c/cumsum_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc index 39dcc6ef..16a9c598 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc @@ -18,8 +18,8 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/fp32/deconv_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/deconv_fp32.h" +#include "nnacl_c/op_base.h" #include "src/litert/tensor_category.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc index 08497018..901dcd06 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/depth_to_space_fp32_test.cc @@ -15,10 +15,10 @@ */ #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/base/depth_to_space_base.h" -#include "nnacl/common_func.h" -#include "nnacl/depth_to_space_parameter.h" -#include "nnacl/kernel/depth_to_space.h" +#include "nnacl_c/base/depth_to_space_base.h" +#include "nnacl_c/common_func.h" +#include "nnacl_c/depth_to_space_parameter.h" +#include "nnacl_c/kernel/depth_to_space.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc 
index f665feb9..92fc83ca 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc
@@ -16,7 +16,7 @@
 #include
 #include "src/litert/kernel/cpu/fp32/embedding_lookup_fp32.h"
-#include "nnacl/fp32/embedding_lookup_fp32.h"
+#include "nnacl_c/fp32/embedding_lookup_fp32.h"
 #include "src/common/file_utils.h"
 #include "common/common_test.h"
 #include "src/common/log_adapter.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc
index 29d0ed36..1c9308f5 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include
 #include "common/common_test.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 #include "src/common/file_utils.h"
 #include "src/litert/tensor_category.h"
 #include "src/common/log_adapter.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc
index 68cfdbc3..58cbcc24 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/logicalor_fp32_test.cc
@@ -16,7 +16,7 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc
index 687562f0..deca3585 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lsh_projection_fp32_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/lsh_projection_parameter.h"
+#include "nnacl_c/lsh_projection_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc
index 9972cf5e..9441e03b 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/lstm_fp32.h"
+#include "nnacl_c/fp32/lstm_fp32.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc
index 5f205625..b7dcc712 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc
@@ -16,8 +16,8 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/pack_fp32.h"
"nnacl/fp32/pack_fp32.h" -#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl_c/fp32/pack_fp32.h" +#include "nnacl_c/fp32/matmul_fp32.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/litert/tensor_category.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc index 3c3b47c8..6fdbba8f 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #include "src/executor/kernel_exec.h" #include "src/litert/tensor_category.h" -#include "nnacl/nllloss_parameter.h" +#include "nnacl_c/nllloss_parameter.h" #include "nnacl/nnacl_manager.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc index c23773fe..cb9ec5c8 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/one_hot_fp32_test.cc @@ -17,7 +17,7 @@ #include "src/executor/kernel_exec.h" #include "src/tensor.h" #include "common/common_test.h" -#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl_c/fp32/one_hot_fp32.h" #include "src/litert/kernel_registry.h" #include "schema/ops_generated.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc index 8d936215..7a942a14 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc @@ -19,7 +19,7 @@ #include "src/executor/kernel_exec.h" #include "src/litert/tensor_category.h" #include "nnacl/nnacl_manager.h" -#include "nnacl/pow_parameter.h" +#include "nnacl_c/pow_parameter.h" namespace mindspore { class TestPowerFp32 : public mindspore::CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc index aac8dd24..2d536075 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/ragged_range_fp32_tests.cc @@ -16,7 +16,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/fp32/ragged_range_fp32.h" +#include "nnacl_c/fp32/ragged_range_fp32.h" #include "src/tensor.h" #include "src/executor/kernel_exec.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc index 7db8176c..ca6857e4 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reduce_fp32_tests.cc @@ -16,7 +16,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/reduce_fp32.h" +#include "nnacl_c/fp32/reduce_fp32.h" #include "schema/inner/model_generated.h" #include "src/tensor.h" #include "nnacl/nnacl_manager.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc index 
index 7edc4e2d..0250d9fb 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_bilinear_fp32_tests.cc
@@ -18,7 +18,7 @@
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
-#include "nnacl/resize_parameter.h"
+#include "nnacl_c/resize_parameter.h"
 #include "schema/ops_generated.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc
index 01acf479..800a586e 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/resize_nearest_neighbor_fp32_tests.cc
@@ -15,7 +15,7 @@
  */
 #include
 #include "common/common_test.h"
-#include "nnacl/resize_parameter.h"
+#include "nnacl_c/resize_parameter.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc
index 4f738dc9..0a1f041a 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/reverse_sequence_fp32_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include
 #include "common/common_test.h"
-#include "nnacl/fp32/reverse_sequence_fp32.h"
+#include "nnacl_c/fp32/reverse_sequence_fp32.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc
index 785ca827..5b16bb68 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scale_fp32_tests.cc
@@ -17,10 +17,10 @@
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
 #include "common/common_test.h"
-#include "nnacl/pad_parameter.h"
+#include "nnacl_c/pad_parameter.h"
 #include "schema/ops_generated.h"
-#include "nnacl/fp32/scale_fp32.h"
-#include "nnacl/scale_parameter.h"
+#include "nnacl_c/fp32/scale_fp32.h"
+#include "nnacl_c/scale_parameter.h"
 #include "nnacl/nnacl_manager.h"
 using mindspore::schema::ActivationType;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc
index ef8a7b11..47cf6f3e 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_add_fp32_test.cc
@@ -16,7 +16,7 @@
 #include "common/common_test.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/scatter_nd_parameter.h"
+#include "nnacl_c/scatter_nd_parameter.h"
 namespace mindspore {
 using mindspore::lite::Tensor;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc
index 117950c1..b171b5b5 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/scatter_nd_fp32_tests.cc
@@ -15,7 +15,7 @@
  */
 #include "common/common_test.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/scatter_nd_parameter.h"
"nnacl/scatter_nd_parameter.h" +#include "nnacl_c/scatter_nd_parameter.h" namespace mindspore { using mindspore::lite::Tensor; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc index 0f079ef5..aaa12c6e 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/skip_gram_fp32.cc @@ -16,7 +16,7 @@ #include #include "src/litert/kernel/cpu/string/skip_gram.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" #include "src/common/file_utils.h" #include "src/litert/tensor_category.h" #include "common/common_test.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc index 84aee370..57a287d4 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/softmax_tests.cc @@ -16,7 +16,7 @@ #include #include #include "common/common_test.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "src/litert/kernel/cpu/nnacl/nnacl_manager.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc index 20f49eab..88e177b0 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc @@ -17,7 +17,7 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc index 4ee7d255..1ebbe4f8 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc @@ -18,8 +18,8 @@ #include #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/space_to_depth_parameter.h" -#include "nnacl/base/space_to_depth_base.h" +#include "nnacl_c/space_to_depth_parameter.h" +#include "nnacl_c/base/space_to_depth_base.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc index d390a003..08704f99 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/sparse_to_dense_fp32_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl_c/fp32/sparse_to_dense_fp32.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc index c454945b..758e7f3f 
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "nnacl/base/stack_base.h"
+#include "nnacl_c/base/stack_base.h"
 namespace mindspore {
 class StackTestFp32 : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc
index 44e538cf..fde4a2bf 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/tile_fp32_tests.cc
@@ -16,7 +16,7 @@
 #include
 #include "common/common_test.h"
-#include "nnacl/tile_parameter.h"
+#include "nnacl_c/tile_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/litert/kernel/cpu/nnacl/nnacl_manager.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
index be71e7c8..f6852855 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include
 #include "common/common_test.h"
-#include "nnacl/fp32/topk_fp32.h"
+#include "nnacl_c/fp32/topk_fp32.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc
index 1978d45c..87210921 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc
@@ -18,8 +18,8 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/transpose_fp32.h"
-#include "nnacl/transpose_parameter.h"
+#include "nnacl_c/fp32/transpose_fp32.h"
+#include "nnacl_c/transpose_parameter.h"
 #include "nnacl/nnacl_manager.h"
 #include "src/executor/kernel_exec.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc
index 4a2989bf..be300dea 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/uniform_real_fp32_test.cc
@@ -16,7 +16,7 @@
 #include
 #include "common/common_test.h"
-#include "nnacl/random_parameter.h"
+#include "nnacl_c/random_parameter.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc
index c5c18492..047d16ac 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unique_fp32_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include
 #include "common/common_test.h"
-#include "nnacl/fp32/unique_fp32.h"
+#include "nnacl_c/fp32/unique_fp32.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc
index 09d61569..4b57c33c 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/unstack_fp32_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include
 #include "common/common_test.h"
-#include "nnacl/base/unstack_base.h"
+#include "nnacl_c/base/unstack_base.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc
index fb7c90c1..1d0ef614 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/activation_grad_fp32_tests.cc
@@ -25,7 +25,7 @@
 #include "src/tensor.h"
 #include "src/executor/kernel_exec.h"
 #include "src/litert/kernel/cpu/fp32_grad/activation_grad.h"
-#include "nnacl/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
 namespace mindspore {
 class TestActGradFp32 : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc
index 144dd5c1..72265769 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/arithmetic_grad_fp32_tests.cc
@@ -21,7 +21,7 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "nnacl/fp32/reduce_fp32.h"
+#include "nnacl_c/fp32/reduce_fp32.h"
 #include "src/litert/kernel/cpu/fp32_grad/arithmetic_grad.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc
index 16d44cfa..642f5e20 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc
@@ -19,10 +19,10 @@
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
 #include "src/litert/kernel/cpu/fp32_grad/bn_grad.h"
-#include "nnacl/fp32_grad/batch_norm_grad.h"
-#include "nnacl/fp32/batchnorm_fp32.h"
+#include "nnacl_c/fp32_grad/batch_norm_grad.h"
+#include "nnacl_c/fp32/batchnorm_fp32.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 namespace mindspore {
 constexpr int kSize3 = 3;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc
index dbc4dd0e..669032a9 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc
@@ -23,7 +23,7 @@
 #include "src/litert/kernel/cpu/fp32_grad/convolution.h"
 #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.h"
 #include "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc
index f5446f12..cd1ae7e5 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc
@@ -20,7 +20,7 @@
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
 #include "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc
index c6eda84e..473caf06 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc
@@ -20,8 +20,8 @@
 #include "common/common_test.h"
 #include "src/common/utils.h"
 #include "src/common/file_utils.h"
-#include "nnacl/fp32_grad/pooling_grad.h"
-#include "nnacl/kernel/pooling.h"
+#include "nnacl_c/fp32_grad/pooling_grad.h"
+#include "nnacl_c/kernel/pooling.h"
 #include "src/litert/kernel/cpu/fp32_grad/pooling_grad.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc
index 5e0be9ef..b04e75c2 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_grad_fp32_tests.cc
@@ -22,7 +22,7 @@
 #include "src/common/utils.h"
 #include "src/common/file_utils.h"
 #include "src/litert/kernel/cpu/fp32_grad/softmax_grad.h"
-#include "nnacl/fp32_grad/softmax_grad.h"
+#include "nnacl_c/fp32_grad/softmax_grad.h"
 #include "src/litert/kernel_registry.h"
 namespace mindspore {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc
index f4463fce..8eb3c104 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc
@@ -17,7 +17,7 @@
 #include
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "nnacl/arithmetic_self_parameter.h"
+#include "nnacl_c/arithmetic_self_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc
index eb560385..72abe37c 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/batchnorm_int8_test.cc
@@ -17,8 +17,8 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/batchnorm_parameter.h"
-#include "nnacl/int8/batchnorm_int8.h"
+#include "nnacl_c/batchnorm_parameter.h"
+#include "nnacl_c/int8/batchnorm_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc
index 52d09191..ca25ffd2 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/concat_parameter.h"
+#include "nnacl_c/concat_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc
index 67159248..0ba9e2a7 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/conv_1x1_int8_tests.cc
@@ -17,8 +17,8 @@
 #include "common/common_test.h"
 #include "src/executor/kernel_exec.h"
 #include "src/common/file_utils.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/common_func.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/common_func.h"
 #include "src/litert/kernel/cpu/int8/convolution_1x1_int8.h"
 #include "src/litert/tensor_category.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc
index a7f57980..a301fba6 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/crop_parameter.h"
+#include "nnacl_c/crop_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
 #include "src/tensor.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
index fe80324f..80c3d256 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
@@ -20,9 +20,9 @@
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/pack.h"
-#include "nnacl/fp32/matmul_fp32.h"
-#include "nnacl/int8/deconv_int8.h"
+#include "nnacl_c/pack.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
+#include "nnacl_c/int8/deconv_int8.h"
 #include "src/litert/kernel/cpu/int8/deconvolution_int8.h"
 using mindspore::lite::DeviceType;
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc
index 64bf2768..21e9c732 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc
@@ -17,8 +17,8 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/litert/kernel/cpu/int8/fullconnection_int8.h"
-#include "nnacl/common_func.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/quantize.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
index fc5376dc..c3bbcf72 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gatherNd_int8_test.cc
@@ -16,9 +16,9 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/gatherNd_fp32.h"
-#include "nnacl/int8/gatherNd_int8.h"
-#include "nnacl/gather_nd_parameter.h"
+#include "nnacl_c/fp32/gatherNd_fp32.h"
+#include "nnacl_c/int8/gatherNd_int8.h"
+#include "nnacl_c/gather_nd_parameter.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc
index f136ea84..f2171064 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/gather_int8_test.cc
@@ -16,8 +16,8 @@
 #include
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
-#include "nnacl/gather_parameter.h"
-#include "nnacl/int8/gather_int8.h"
+#include "nnacl_c/gather_parameter.h"
+#include "nnacl_c/int8/gather_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc
index cd3a3f9b..ef03b3f5 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc
@@ -18,7 +18,7 @@
 #include
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 #include "src/litert/kernel/cpu/int8/hswish_int8.h"
 #include "src/litert/kernel_registry.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc
index 9109f958..b8ea47a5 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/l2_norm_int8_tests.cc
@@ -18,7 +18,7 @@
 #include "schema/inner/model_generated.h"
 #include "common/common_test.h"
 #include "src/litert/kernel_registry.h"
-#include "nnacl/l2_norm_parameter.h"
+#include "nnacl_c/l2_norm_parameter.h"
 namespace mindspore {
 class TestL2NormInt8 : public mindspore::CommonTest {
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
index 588f11fd..fde391e5 100644
--- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
+++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
@@ -18,9 +18,9 @@
 #include "src/common/log_adapter.h"
 #include "common/common_test.h"
 #include "src/litert/kernel/cpu/int8/matmul_int8.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/common_func.h"
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/common_func.h"
+#include "nnacl_c/int8/matmul_int8.h"
 #include "src/litert/kernel_registry.h"
 #include "src/executor/kernel_exec.h"
diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc
2da74fa2..e42c1122 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc @@ -18,11 +18,11 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/mul_parameter.h" +#include "nnacl_c/mul_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc index d8cbc4bb..1d94b775 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc @@ -20,7 +20,7 @@ #include "src/litert/tensor_category.h" #include "common/common_test.h" #include "src/common/file_utils.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" #include "src/litert/kernel/cpu/int8/pad_int8.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc index 039e37ef..3899c22a 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/power_int8_tests.cc @@ -19,7 +19,7 @@ #include "schema/inner/model_generated.h" #include "common/common_test.h" #include "src/litert/kernel/cpu/int8/power_int8.h" -#include "nnacl/pow_parameter.h" +#include "nnacl_c/pow_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc index e87b3859..0e55eb45 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc index d362571a..576225b6 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc @@ -19,7 +19,7 @@ #include "src/common/log_adapter.h" #include "common/common_test.h" #include "src/litert/kernel/cpu/base/quant_dtype_cast.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc index 88099503..c9e038eb 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reduce_int8_tests.cc @@ -19,7 +19,7 @@ #include "common/common_test.h" #include "src/tensor.h" 
#include "src/litert/kernel_registry.h" -#include "nnacl/fp32/reduce_fp32.h" +#include "nnacl_c/fp32/reduce_fp32.h" namespace mindspore { using mindspore::lite::LiteQuantParam; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc index 794f0e05..a765e82f 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/reshape_parameter.h" +#include "nnacl_c/reshape_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc index 9777c6b6..c79ef65b 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_bilinear_int8_tests.cc @@ -19,7 +19,7 @@ #include "src/tensor.h" #include "common/common_test.h" #include "src/litert/kernel_registry.h" -#include "nnacl/int8/resize_int8.h" +#include "nnacl_c/int8/resize_int8.h" namespace mindspore { using mindspore::lite::LiteQuantParam; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc index d1b7985d..c664a889 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/resize_nearest_neighbor_int8_tests.cc @@ -19,7 +19,7 @@ #include "src/tensor.h" #include "common/common_test.h" #include "src/litert/kernel_registry.h" -#include "nnacl/int8/resize_int8.h" +#include "nnacl_c/int8/resize_int8.h" namespace mindspore { using mindspore::lite::LiteQuantParam; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc index 4d16b9a9..2df0fbc2 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/scale_int8.cc @@ -19,7 +19,7 @@ #include "common/common_test.h" #include "src/tensor.h" #include "src/litert/kernel_registry.h" -#include "nnacl/int8/scale_int8.h" +#include "nnacl_c/int8/scale_int8.h" namespace mindspore { using mindspore::lite::LiteQuantParam; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc index 251c68d6..148ea7b5 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc @@ -17,7 +17,7 @@ #include #include "schema/inner/model_generated.h" #include "common/common_test.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc index cd58822b..3893a239 100644 --- 
a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc @@ -19,7 +19,7 @@ #include "schema/inner/model_generated.h" #include "common/common_test.h" #include "src/litert/kernel/cpu/int8/softmax_int8.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc index 4a7a0e86..d3ad1f6b 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/space_to_batch_int8_tests.cc @@ -15,7 +15,7 @@ */ #include #include "common/common_test.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc index f528a8a8..7fe33d32 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc index 46b8c22a..51b6102b 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc @@ -18,7 +18,7 @@ #include "schema/inner/model_generated.h" #include "src/common/log_adapter.h" #include "common/common_test.h" -#include "nnacl/squeeze_parameter.h" +#include "nnacl_c/squeeze_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include "src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc index 2e6051f2..173e0136 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc @@ -18,7 +18,7 @@ #include #include "schema/inner/model_generated.h" #include "common/common_test.h" -#include "nnacl/fp32/topk_fp32.h" +#include "nnacl_c/fp32/topk_fp32.h" #include "src/litert/kernel_registry.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc index a31ba896..d509b258 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc @@ -17,7 +17,7 @@ #include #include "schema/inner/model_generated.h" #include "common/common_test.h" -#include "nnacl/unsqueeze_parameter.h" +#include "nnacl_c/unsqueeze_parameter.h" #include "src/litert/kernel_registry.h" #include "src/executor/kernel_exec.h" #include 
"src/tensor.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc index c79d737b..6644e069 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/string/normalize.cc @@ -18,7 +18,7 @@ #include #include "src/litert/kernel/cpu/string/skip_gram.h" #include "src/litert/kernel_registry.h" -#include "nnacl/skip_gram_parameter.h" +#include "nnacl_c/skip_gram_parameter.h" #include "src/common/file_utils.h" #include "common/common_test.h" #include "src/common/log_adapter.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc index 2b176a0f..a3d61223 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/cuda/batchtospace_tests.cc @@ -18,7 +18,7 @@ #include "schema/ops_generated.h" #include "src/extendrt/kernel/cuda/batchtospace.h" #include "ut/src/extendrt/kernel/cuda/common.h" -#include "nnacl/batch_to_space_parameter.h" +#include "nnacl_c/batch_to_space_parameter.h" namespace mindspore { class CudaTest_BatchToSpace : public CommonTest { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc index bb53b44e..bb591e81 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc index 0e614fbb..7f2b3180 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/arg_min_max_parameter.h" +#include "nnacl_c/arg_min_max_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc index 99aaf1f6..6ebfd6e1 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/arithmetic_self_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc index a5c56665..19633ff7 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc index d55c7856..76b0edf2 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc @@ -15,7 +15,7 @@ */ #include #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/batch_to_space_parameter.h" +#include "nnacl_c/batch_to_space_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc index 6121eff5..ad411f94 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/batchnorm_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc index 2305511f..a7783bf9 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.cc @@ -18,7 +18,7 @@ #include #include "src/litert/kernel_registry.h" #include "src/litert/kernel/opencl/opencl_subgraph.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" using mindspore::kernel::KernelExec; using mindspore::kernel::OpenCLSubGraph; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h index a0b5be1e..81318af6 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/common.h @@ -23,7 +23,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ir/dtype/type_id.h" #include "src/tensor.h" #include "src/litert/tensor_category.h" diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc index 37536eb3..c1aef222 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc index e45ec806..57ec7dff 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc index e143c1ee..a7a5c405 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc index bc4e9b78..14569e4d 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/crop_tests.cc @@ -15,7 +15,7 @@ */ #include #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/crop_parameter.h" +#include "nnacl_c/crop_parameter.h" namespace mindspore::lite::opencl::test { class TestOpenCL_Crop : public CommonTest {}; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc index a0988660..42573cfc 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc index 1a4aebe5..4ca48a2e 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc index d9c6a487..12cf982e 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/gather_parameter.h" +#include "nnacl_c/gather_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc index 6b8e23b2..c1b95dbd 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/layer_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc index 24a773c8..a84b161f 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc index d741989a..d4c1659b 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/one_hot_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl_c/fp32/one_hot_fp32.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc index e376e922..07fb93b3 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/pad_parameter.h" +#include "nnacl_c/pad_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc index 731c3a4e..6975db16 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/pooling_parameter.h" +#include "nnacl_c/pooling_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc index 8915cff7..688ece36 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/prelu_parameter.h" +#include "nnacl_c/prelu_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc index 538f9ce1..67f0924a 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc index 38bf78c0..c623f643 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/reshape_parameter.h" +#include "nnacl_c/reshape_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc index e849da46..ddfccc45 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc index c7423032..bf2b99ca 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/scale_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc index 0dc26b3d..fde49b6f 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "nnacl/slice_parameter.h" +#include "nnacl_c/slice_parameter.h" #include "ut/src/runtime/kernel/opencl/common.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc index 420e341c..b05bb9b4 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc index 91da7823..8784b57e 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl_c/fp32/space_to_batch_fp32.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc index 3b6f31c3..0265de50 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/space_to_depth_tests.cc @@ -14,9 +14,9 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/space_to_depth_parameter.h" -#include "nnacl/base/space_to_depth_base.h" -#include "nnacl/depth_to_space_parameter.h" +#include "nnacl_c/space_to_depth_parameter.h" +#include "nnacl_c/base/space_to_depth_base.h" +#include "nnacl_c/depth_to_space_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc index 1aa86bcb..344f0876 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/sparse_to_dense_parameter.h" +#include "nnacl_c/sparse_to_dense_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc index 4e1ccbe2..ce1e3ef1 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/split_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/split_parameter.h" +#include "nnacl_c/split_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc index 8fa43658..1a2ed573 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc index bd6b3d48..a80a328d 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/strided_slice_parameter.h" +#include "nnacl_c/strided_slice_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc index 12c14702..4a399961 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "ut/src/runtime/kernel/opencl/common.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" namespace mindspore::lite::opencl::test { diff --git a/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc b/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc index e8eb57e9..ea1a5350 100644 --- a/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/runtime_pass_tests.cc @@ -17,10 +17,10 @@ #include "src/executor/kernel_exec.h" #include "src/litert/kernel_registry.h" #include "src/litert/runtime_pass.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/instance_norm_parameter.h" -#include "nnacl/fp32/activation_fp32.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/instance_norm_parameter.h" +#include "nnacl_c/fp32/activation_fp32.h" +#include "nnacl_c/transpose_parameter.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc index d5d7a399..de4ee5bd 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/activation_fusion_inout_test.cc @@ -18,7 +18,7 @@ #include #include "tools/optimizer/fusion/activation_fusion.h" #include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc index 39260b43..5f304d49 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/add_concat_act_fusion_inout_test.cc @@ -23,7 +23,7 @@ #include "include/backend/optimizer/optimizer.h" #include "include/backend/optimizer/pass_manager.h" #include "tools/optimizer/fusion/add_concat_activation_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/cxx_api/add_fusion.h" diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc index eb9a2eb1..c53b7eaf 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc @@ -18,7 +18,7 @@ #include 
#include "tools/optimizer/fusion/conv_activation_fusion.h" #include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc index 8725b492..be04c79d 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc @@ -18,7 +18,7 @@ #include #include "tools/optimizer/fusion/conv_biasadd_fusion.h" #include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" namespace mindspore { diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc index 69b92363..ac5766cf 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc @@ -20,7 +20,7 @@ #include "src/common/log_adapter.h" #include "ir/func_graph.h" #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { ValueNodePtr ConvFusionInoutTest::CreateConvPrimitiveValue() { diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc index e9781501..e3fb942e 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/conv2d_fusion.h" #include "infer/make_tuple.h" #include "infer/return.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_adapter.h" #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc index 1a509bfd..e6ea162d 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_act_fusion_inout_test.cc @@ -18,7 +18,7 @@ #include #include "tools/optimizer/fusion/matmul_activation_fusion.h" #include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/mat_mul_fusion.h" #include "infer/cxx_api/activation.h" diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc index fa531e8e..fefb608d 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc +++ 
b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.cc @@ -20,7 +20,7 @@ #include "src/common/log_adapter.h" #include "ir/func_graph.h" #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { CNodePtr MatMulFusionInoutTest::AddMatMul(const FuncGraphPtr &graph, const AnfNodePtr &input1, const AnfNodePtr &input2, diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h index d627f5a5..57865288 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h @@ -20,7 +20,7 @@ #include #include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h" #include "ir/anf.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "include/backend/optimizer/pass.h" #include "include/backend/optimizer/optimizer.h" #include "include/backend/optimizer/pass_manager.h" diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc index 3efe5807..5e6ac368 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_mul_fusion_inout_test.cc @@ -18,7 +18,7 @@ #include #include "tools/optimizer/fusion/matmul_mul_fusion.h" #include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/mul_fusion.h" #include "infer/cxx_api/mat_mul_fusion.h" diff --git a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc index 961c8be5..1d40ce2c 100644 --- a/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc +++ b/mindspore-lite/test/ut/tools/optimizer/fusion/fusion_inout_test/trans_matmul_fusion_inout_test.cc @@ -19,7 +19,7 @@ #include "tools/optimizer/fusion/transpose_matmul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" #include "test/ut/tools/optimizer/fusion/fusion_inout_test/matmul_fusion_inout_test.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" namespace mindspore { diff --git a/mindspore-lite/tools/benchmark/CMakeLists.txt b/mindspore-lite/tools/benchmark/CMakeLists.txt index 0af2d403..dcb83dd2 100644 --- a/mindspore-lite/tools/benchmark/CMakeLists.txt +++ b/mindspore-lite/tools/benchmark/CMakeLists.txt @@ -61,7 +61,7 @@ if(MSLITE_EXPORT_COMPUTE_IR) set(BENCHMARK_LINK_LIB ${BENCHMARK_LINK_LIB} mindspore_lite_drawer) endif() -include_directories(${OPS_DIR}/kernel/cpu) +include_directories(${NNACL_DIR}/../) set(COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc @@ -69,7 +69,7 @@ set(COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/config_file.cc - 
${OPS_DIR}/kernel/cpu/nnacl/nnacl_common.c + ${NNACL_DIR}/nnacl_common.c ) include_directories(${TOP_DIR}/mindspore-lite) diff --git a/mindspore-lite/tools/benchmark/benchmark_base.h b/mindspore-lite/tools/benchmark/benchmark_base.h index ab373b25..d70491ab 100644 --- a/mindspore-lite/tools/benchmark/benchmark_base.h +++ b/mindspore-lite/tools/benchmark/benchmark_base.h @@ -41,7 +41,7 @@ #include "src/common/utils.h" #include "ir/dtype/type_id.h" #include "schema/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { #define BENCHMARK_LOG_ERROR(str) \ diff --git a/mindspore-lite/tools/benchmark/benchmark_unified_api.cc b/mindspore-lite/tools/benchmark/benchmark_unified_api.cc index 0d902d2d..4b26f642 100644 --- a/mindspore-lite/tools/benchmark/benchmark_unified_api.cc +++ b/mindspore-lite/tools/benchmark/benchmark_unified_api.cc @@ -26,7 +26,7 @@ #include "src/common/common.h" #include "src/tensor.h" #include "tools/common/string_util.h" -#include "nnacl/nnacl_common.h" +#include "nnacl_c/nnacl_common.h" #ifdef ENABLE_ARM64 #include #include diff --git a/mindspore-lite/tools/benchmark_train/CMakeLists.txt b/mindspore-lite/tools/benchmark_train/CMakeLists.txt index a915c581..75ce0407 100644 --- a/mindspore-lite/tools/benchmark_train/CMakeLists.txt +++ b/mindspore-lite/tools/benchmark_train/CMakeLists.txt @@ -5,7 +5,7 @@ set(COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc ) - +include_directories(${NNACL_DIR}/../) set(TEST_SRC ${CMAKE_CURRENT_SOURCE_DIR}/main.cc ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc diff --git a/mindspore-lite/tools/common/func_graph_subgraph.cc b/mindspore-lite/tools/common/func_graph_subgraph.cc index 76310685..6abd2eda 100644 --- a/mindspore-lite/tools/common/func_graph_subgraph.cc +++ b/mindspore-lite/tools/common/func_graph_subgraph.cc @@ -27,7 +27,7 @@ #include "tools/common/graph_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "infer/cxx_api/partial_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/common/graph_util.cc b/mindspore-lite/tools/common/graph_util.cc index 43acd95b..3e407ab8 100644 --- a/mindspore-lite/tools/common/graph_util.cc +++ b/mindspore-lite/tools/common/graph_util.cc @@ -28,7 +28,7 @@ #include "tools/common/tensor_util.h" #include "src/common/log_adapter.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/make_tuple.h" #include "tools/converter/converter_context.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/common/graph_util.h b/mindspore-lite/tools/common/graph_util.h index 6d68c92d..2d15581c 100644 --- a/mindspore-lite/tools/common/graph_util.h +++ b/mindspore-lite/tools/common/graph_util.h @@ -35,7 +35,7 @@ #include "src/common/graph_util.h" #include "ir/anf.h" #include "ir/func_graph.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/node_util.h" #include "tools/converter/cxx_api/converter_para.h" diff --git a/mindspore-lite/tools/common/meta_graph_serializer.cc b/mindspore-lite/tools/common/meta_graph_serializer.cc index 7f30ea7f..0c089604 100644 --- a/mindspore-lite/tools/common/meta_graph_serializer.cc +++ b/mindspore-lite/tools/common/meta_graph_serializer.cc @@ -21,7 +21,7 @@ #endif #include "flatbuffers/flatbuffers.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" 
+#include "nnacl_c/op_base.h" #include "ir/dtype/type_id.h" #include "src/common/utils.h" #include "include/errorcode.h" diff --git a/mindspore-lite/tools/common/meta_graph_utils.cc b/mindspore-lite/tools/common/meta_graph_utils.cc index a4378267..7456e613 100644 --- a/mindspore-lite/tools/common/meta_graph_utils.cc +++ b/mindspore-lite/tools/common/meta_graph_utils.cc @@ -19,7 +19,7 @@ #include #include "inner/model_generated.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { namespace { size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) { diff --git a/mindspore-lite/tools/common/node_util.cc b/mindspore-lite/tools/common/node_util.cc index c8774ad2..3e26389a 100644 --- a/mindspore-lite/tools/common/node_util.cc +++ b/mindspore-lite/tools/common/node_util.cc @@ -31,7 +31,7 @@ #include "mindspore/ops/infer/switch.h" #include "mindspore/ops/infer/call.h" #include "mindspore/ops/infer/cxx_api/partial_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/common/opengl_util.h b/mindspore-lite/tools/common/opengl_util.h index 2ff70a81..9d13fae8 100644 --- a/mindspore-lite/tools/common/opengl_util.h +++ b/mindspore-lite/tools/common/opengl_util.h @@ -21,7 +21,7 @@ #include #include #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #if defined(GPU_OPENCL) && defined(__ANDROID__) && defined(ENABLE_ARM64) #include "EGL/egl.h" diff --git a/mindspore-lite/tools/common/statistic_utils.cc b/mindspore-lite/tools/common/statistic_utils.cc index c1ca1284..a6d42560 100644 --- a/mindspore-lite/tools/common/statistic_utils.cc +++ b/mindspore-lite/tools/common/statistic_utils.cc @@ -16,7 +16,7 @@ #include "tools/common/statistic_utils.h" #if defined(ENABLE_AVX) && defined(__linux__) -#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl_c/intrinsics/ms_simd_cpu_info.h" #ifdef _MSC_VER #include #else diff --git a/mindspore-lite/tools/common/statistic_utils.h b/mindspore-lite/tools/common/statistic_utils.h index f4d8ab00..93e8cfb7 100644 --- a/mindspore-lite/tools/common/statistic_utils.h +++ b/mindspore-lite/tools/common/statistic_utils.h @@ -25,7 +25,7 @@ #include #include "include/errorcode.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindapi/base/type_id.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/common/tensor_util.cc b/mindspore-lite/tools/common/tensor_util.cc index 1e0131ad..753cb372 100644 --- a/mindspore-lite/tools/common/tensor_util.cc +++ b/mindspore-lite/tools/common/tensor_util.cc @@ -19,7 +19,7 @@ #include "src/common/utils.h" #include "tools/common/graph_util.h" #include "abstract/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ir/tensor_new.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/converter/CMakeLists.txt b/mindspore-lite/tools/converter/CMakeLists.txt index d243a751..f5ff2c2c 100644 --- a/mindspore-lite/tools/converter/CMakeLists.txt +++ b/mindspore-lite/tools/converter/CMakeLists.txt @@ -14,7 +14,7 @@ endif() include(${LITE_DIR}/cmake/ccsrc_converter.cmake) -include_directories(${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu) +include_directories(${NNACL_DIR}/..) 
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc diff --git a/mindspore-lite/tools/converter/adapter/acl/common/utils.cc b/mindspore-lite/tools/converter/adapter/acl/common/utils.cc index e20e5097..4164288a 100644 --- a/mindspore-lite/tools/converter/adapter/acl/common/utils.cc +++ b/mindspore-lite/tools/converter/adapter/acl/common/utils.cc @@ -25,7 +25,7 @@ #include "include/common/utils/utils.h" #include "src/common/log_util.h" #include "ir/func_graph.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc b/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc index 925123cc..6582cfad 100644 --- a/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc +++ b/mindspore-lite/tools/converter/adapter/acl/infer/custom_infer.cc @@ -20,7 +20,7 @@ #include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" #include "tools/converter/adapter/acl/common/acl_types.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc b/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc index f829263b..3b738d7f 100644 --- a/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc +++ b/mindspore-lite/tools/converter/adapter/acl/infer/flash_attention_infer.cc @@ -19,7 +19,7 @@ #include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" #include "tools/converter/adapter/acl/common/acl_types.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace kernel { diff --git a/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc b/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc index 4fbf6df8..b85d4eb4 100644 --- a/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc +++ b/mindspore-lite/tools/converter/adapter/acl/infer/forward_rasterize_infer.cc @@ -19,7 +19,7 @@ #include "include/registry/register_kernel_interface.h" #include "common/log_adapter.h" #include "tools/converter/adapter/acl/common/acl_types.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace kernel { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc index 99c655d6..e49b0abf 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/arithmetic_mapper.cc @@ -24,7 +24,7 @@ #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc index 3d53d912..de8ec126 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/cast_mapper.cc @@ -20,7 +20,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/common/utils.h" 
#include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_c.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc index 4962713c..7ae050f4 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/constant_of_shape_mapper.cc @@ -24,7 +24,7 @@ #include "ops_utils/op_utils.h" #include "src/common/log_util.h" #include "tools/common/tensor_util.h" -#include "mindspore/ops/kernel/cpu/nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc index 4101bd22..b8971380 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/conv2d_transpose_fusion_mapper.cc @@ -21,7 +21,7 @@ #include "tools/converter/adapter/acl/common/utils.h" #include "include/registry/converter_context.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc index 7b47bf6c..2580324c 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/conv_base_mapper.cc @@ -18,7 +18,7 @@ #include #include #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "utils/check_convert_utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc index 72b69e3e..ac976611 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_g.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc index 0e7bdfe7..6420779e 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_fusion_mapper.cc @@ -20,7 +20,7 @@ #include #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_g.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc index 4afcd138..831afebb 100644 --- 
a/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/gru_mapper.cc @@ -21,7 +21,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "tools/converter/adapter/acl/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "src/common/log_util.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc index 6ea31068..968a9b01 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/lstm_mapper.cc @@ -22,7 +22,7 @@ #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "tools/converter/adapter/acl/common/utils.h" #include "infer/lstm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc index c7a9807c..04bc0e3c 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc @@ -31,7 +31,7 @@ #include "mindspore/ops/op_def/op_name.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops/base_operator.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc index d7ef4e40..fc2813b4 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc @@ -21,7 +21,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc index ce89b4d6..74a0453e 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc @@ -20,7 +20,7 @@ #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "include/registry/converter_context.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_o.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc index c13b4d04..e9a72850 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc +++ 
b/mindspore-lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc @@ -26,7 +26,7 @@ #include "ops_utils/op_utils.h" #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc index b147832a..902af110 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/quant_dtype_cast_mapper.cc @@ -24,7 +24,7 @@ #include "src/common/log_util.h" #include "mindspore/ops/op_def/op_name.h" #include "infer/quant_dtype_cast.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc index 5faba4a5..4a363a58 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/reshape_mapper.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/adapter/acl/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_r.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc index 1c8af3f0..d52622b7 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/resize_mapper.cc @@ -28,7 +28,7 @@ #include "ops_utils/op_utils.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc index 1d58cc8e..9b252148 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/stridedslice_mapper.cc @@ -20,7 +20,7 @@ #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "include/registry/converter_context.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_s.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc index bb0d07ca..2f02609c 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/tile_fusion_mapper.cc @@ -22,7 +22,7 @@ #include "tools/converter/adapter/acl/common/utils.h" #include 
"mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/lite_exporter/fetch_content.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc index 60a48800..99e30517 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/topk_fusion_mapper.cc @@ -23,7 +23,7 @@ #include "src/common/log_util.h" #include "infer/topk.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc index b99b8783..4dd96725 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/transpose_mapper.cc @@ -21,7 +21,7 @@ #include #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_name_t.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc index d75fd3ce..fd383b25 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/upsample_mapper.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/adapter/acl/common/utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc index f73d66ec..0aff2c4f 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/where_mapper.cc @@ -19,7 +19,7 @@ #include "tools/converter/adapter/acl/common/utils.h" #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc index f1752372..197203c0 100644 --- a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc +++ b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc @@ -38,7 +38,7 @@ #include "infer/standard_normal.h" #include "infer/tuple_get_item.h" #include "cxx_api/model/acl/model_converter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "src/common/log_util.h" #include "src/common/file_utils.h" diff --git a/mindspore-lite/tools/converter/anf_transform.cc b/mindspore-lite/tools/converter/anf_transform.cc index 15010562..406986a0 100644 --- a/mindspore-lite/tools/converter/anf_transform.cc +++ 
b/mindspore-lite/tools/converter/anf_transform.cc @@ -22,7 +22,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_adapter.h" #include "tools/converter/optimizer_manager.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/anf_transform_for_ge.cc b/mindspore-lite/tools/converter/anf_transform_for_ge.cc index 9f6b9bf0..41d8cc61 100644 --- a/mindspore-lite/tools/converter/anf_transform_for_ge.cc +++ b/mindspore-lite/tools/converter/anf_transform_for_ge.cc @@ -22,7 +22,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_adapter.h" #include "tools/converter/optimizer_manager.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc b/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc index aab7d32d..1d47c900 100644 --- a/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc +++ b/mindspore-lite/tools/converter/config_parser/acl_option_param_parser.cc @@ -20,7 +20,7 @@ #include "tools/common/string_util.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/converter.cc b/mindspore-lite/tools/converter/converter.cc index 5d829150..196d7287 100644 --- a/mindspore-lite/tools/converter/converter.cc +++ b/mindspore-lite/tools/converter/converter.cc @@ -36,7 +36,7 @@ #include "src/common/log_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/converter/import/mindspore_importer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/micro/coder/coder.h" #include "src/common/prim_util.h" #include "src/common/version_manager.h" diff --git a/mindspore-lite/tools/converter/converter_funcgraph.cc b/mindspore-lite/tools/converter/converter_funcgraph.cc index df219c9b..289740dd 100644 --- a/mindspore-lite/tools/converter/converter_funcgraph.cc +++ b/mindspore-lite/tools/converter/converter_funcgraph.cc @@ -34,7 +34,7 @@ #include "src/common/log_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/converter/import/mindspore_importer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/micro/coder/coder.h" #include "src/common/prim_util.h" #include "src/common/version_manager.h" diff --git a/mindspore-lite/tools/converter/converter_packed_node.cc b/mindspore-lite/tools/converter/converter_packed_node.cc index beb91a4c..45a6a617 100644 --- a/mindspore-lite/tools/converter/converter_packed_node.cc +++ b/mindspore-lite/tools/converter/converter_packed_node.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/op_name.h" #include "src/litert/kernel/cpu/fp32/matmul_fp32.h" #include "src/litert/kernel/cpu/nnacl/nnacl_kernel.h" -#include "nnacl/kernel/matmul_struct.h" +#include "nnacl_c/kernel/matmul_struct.h" namespace mindspore { namespace { diff --git a/mindspore-lite/tools/converter/export_model.cc b/mindspore-lite/tools/converter/export_model.cc index d422faa3..9495e215 100644 --- a/mindspore-lite/tools/converter/export_model.cc +++ b/mindspore-lite/tools/converter/export_model.cc @@ -34,7 +34,7 @@ #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/graph/control_flow_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h" -#include 
"nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/converter/import/mindir_adjust.cc b/mindspore-lite/tools/converter/import/mindir_adjust.cc index 3435900e..d9a2eef1 100644 --- a/mindspore-lite/tools/converter/import/mindir_adjust.cc +++ b/mindspore-lite/tools/converter/import/mindir_adjust.cc @@ -26,7 +26,7 @@ #include "src/common/log_adapter.h" #include "src/common/quant_utils.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/fake_quant_param.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc b/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc index 502dd063..7b38e032 100644 --- a/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc +++ b/mindspore-lite/tools/converter/import/mindir_control_flow_adjust.cc @@ -27,7 +27,7 @@ #include "tools/common/node_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" namespace { constexpr const int kSwitchTruePartialIndex = 2; diff --git a/mindspore-lite/tools/converter/import/mindspore_importer.cc b/mindspore-lite/tools/converter/import/mindspore_importer.cc index c496034a..233f9231 100644 --- a/mindspore-lite/tools/converter/import/mindspore_importer.cc +++ b/mindspore-lite/tools/converter/import/mindspore_importer.cc @@ -37,7 +37,7 @@ #include "tools/converter/parser/unify_format.h" #include "tools/converter/parser/lstm_adjust_pass.h" #include "tools/optimizer/graph/redundant_op_remove_pass.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/common.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/tools/converter/import/primitive_adjust.cc b/mindspore-lite/tools/converter/import/primitive_adjust.cc index 308d18fb..c9760a74 100644 --- a/mindspore-lite/tools/converter/import/primitive_adjust.cc +++ b/mindspore-lite/tools/converter/import/primitive_adjust.cc @@ -67,7 +67,7 @@ #include "infer/random_standard_normal.h" #include "infer/fill.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" diff --git a/mindspore-lite/tools/converter/import/remove_public_primitive.cc b/mindspore-lite/tools/converter/import/remove_public_primitive.cc index 26cdfffa..c072441b 100644 --- a/mindspore-lite/tools/converter/import/remove_public_primitive.cc +++ b/mindspore-lite/tools/converter/import/remove_public_primitive.cc @@ -20,7 +20,7 @@ #include #include "mindspore/ops/op_def/structure_ops.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc 
b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc index a551196d..bb04530d 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc +++ b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc @@ -29,7 +29,7 @@ #include "tools/common/meta_graph_utils.h" #include "include/errorcode.h" #include "schema/inner/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h index 1281053f..fa1a39b5 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h +++ b/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pattern.h @@ -24,7 +24,7 @@ #include #include "src/common/log_adapter.h" #include "schema/inner/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index 3913b383..297391a8 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore-lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -22,7 +22,7 @@ #include "tools/converter/optimizer.h" #include "tools/common/graph_util.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index 8f650e0f..dde3e922 100644 --- a/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore-lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -30,7 +30,7 @@ #include "tools/common/node_util.h" #include "src/common/string_utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" using mindspore::converter::kFmkTypeTf; namespace { diff --git a/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc b/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc index cee396e3..9883f4ff 100644 --- a/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc +++ b/mindspore-lite/tools/converter/micro/coder/allocator/memory_manager.cc @@ -16,7 +16,7 @@ #include "tools/converter/micro/coder/allocator/memory_manager.h" #include -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/micro/coder/opcoders/op_coder.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc b/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc index daafa97a..ab2e8f42 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/component/common_component.cc @@ -21,7 +21,7 @@ #include "coder/utils/coder_utils.h" #include "coder/log.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/c_api/model_c.h" #include "coder/generator/component/const_blocks/license.h" #include "tools/common/string_util.h" diff --git 
a/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc b/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc index 7a341581..d0adbc7f 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/component/const_blocks/debug_utils.cc @@ -41,7 +41,7 @@ const char debug_utils_h[] = R"RAW( #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #define MICRO_INFO(content, args...) \ { printf("[INFO] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); } diff --git a/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc b/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc index 90bd64ca..efbb7964 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/component/train_component.cc @@ -17,7 +17,7 @@ #include "coder/generator/component/train_component.h" #include #include "coder/utils/coder_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "coder/utils/type_cast.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/generator/generator.cc b/mindspore-lite/tools/converter/micro/coder/generator/generator.cc index e826b681..7eed5ce7 100644 --- a/mindspore-lite/tools/converter/micro/coder/generator/generator.cc +++ b/mindspore-lite/tools/converter/micro/coder/generator/generator.cc @@ -571,8 +571,8 @@ int Generator::CodeRegKernelHFile() { MS_CHECK_TRUE(!cofs.bad(), "filed to open file"); MS_LOG(INFO) << "write " << reg_kernel_header; cofs << g_hwLicense; - cofs << "#include \"nnacl/tensor_c.h\"\n"; - cofs << "#include \"nnacl/custom_parameter.h\"\n\n"; + cofs << "#include \"nnacl_c/tensor_c.h\"\n"; + cofs << "#include \"nnacl_c/custom_parameter.h\"\n\n"; cofs << KernelRegistry::GetInstance()->GenKernelInterface(kCustomKernelName, kCustomKernelParam) << "\n"; return RET_OK; } diff --git a/mindspore-lite/tools/converter/micro/coder/log.h b/mindspore-lite/tools/converter/micro/coder/log.h index f22ea2b4..68b43c8e 100644 --- a/mindspore-lite/tools/converter/micro/coder/log.h +++ b/mindspore-lite/tools/converter/micro/coder/log.h @@ -19,7 +19,7 @@ #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #define MS_CHECK_PTR(ptr) \ do { \ diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc index 5fd77feb..9e887788 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.cc @@ -17,8 +17,8 @@ #include "coder/opcoders/base/conv2d_base_coder.h" #include #include -#include "nnacl/fp32/winograd_utils.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/int8/quantize.h" #include "coder/log.h" #include "src/litert/tensor_category.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h index b03d9e14..dd5ef61c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/base/conv2d_base_coder.h @@ -23,7 +23,7 @@ #include #include "coder/opcoders/op_coder.h" #include "src/litert/kernel/cpu/base/layout_transform.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro { class Conv2DBaseCoder : public OperatorCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc index 738441b9..ef6131fa 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.cc @@ -15,8 +15,8 @@ */ #include "coder/opcoders/base/detection_post_process_base_coder.h" -#include "mindspore/ops/kernel/cpu/nnacl/op_base.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" @@ -125,8 +125,8 @@ int DetectionPostProcessBaseCoder::AllocateBuffer() { int DetectionPostProcessBaseCoder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/detection_post_process_parameter.h", - "nnacl/fp32/detection_post_process_fp32.h", + "nnacl_c/detection_post_process_parameter.h", + "nnacl_c/fp32/detection_post_process_fp32.h", "wrapper/base/detection_post_process_base_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h index fc2b216a..0099d8f9 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/detection_post_process_base_coder.h @@ -22,7 +22,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/detection_post_process_parameter.h" +#include "nnacl_c/detection_post_process_parameter.h" #include "coder/opcoders/serializers/serializer.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc index be1f041d..6cebd9dc 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.cc @@ -123,7 +123,7 @@ int DTypeCastCoder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/cast_base.h", + "nnacl_c/base/cast_base.h", }, { "cast_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h index 03f9e54f..fcd1421e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/dtype_cast_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" #include "coder/opcoders/serializers/serializer.h" namespace mindspore::lite::micro { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h 
b/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h index c11a703b..51be7340 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/full_connection_base_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro { class FullConnectionBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc index 8d08edd5..71c2d205 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.cc @@ -47,7 +47,7 @@ int QuantDTypeCastCoder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/quant_dtype_cast_int8.h", + "nnacl_c/int8/quant_dtype_cast_int8.h", }, { "quant_dtype_cast_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h index 3e24703e..0aa01322 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/quant_dtype_cast_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl_c/int8/quant_dtype_cast_int8.h" namespace mindspore::lite::micro { class QuantDTypeCastCoder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h index af8d6d21..3b6b6c2f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/reduce_base_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/reduce_parameter.h" +#include "nnacl_c/reduce_parameter.h" namespace mindspore::lite::micro { class ReduceBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h index bc6e5c5f..a6f78d7b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/resize_base_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/resize_parameter.h" +#include "nnacl_c/resize_parameter.h" namespace mindspore::lite::micro { class ResizeBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h index 0ec49451..c295082f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/softmax_base_coder.h @@ -20,8 +20,8 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/softmax_parameter.h" +#include 
"nnacl_c/int8/quantize.h" namespace mindspore::lite::micro { class SoftmaxBaseCoder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc index ee887342..c319ee33 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc @@ -40,7 +40,7 @@ int StackFP32Coder::ReSize() { int StackFP32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/stack_base.h", + "nnacl_c/base/stack_base.h", }, { "stack_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h index 08074332..33f98382 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/stack_parameter.h" +#include "nnacl_c/stack_parameter.h" namespace mindspore::lite::micro::nnacl { class StackFP32Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc index 44622279..07f63c93 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc @@ -176,7 +176,7 @@ int StridedSliceBaseCoder::DoCode(CoderContext *ctx) { inner_size_ = GetInnerSize(input_tensor_->data_type(), inner_); Collect(ctx, { - "nnacl/fp32/strided_slice_fp32.h", + "nnacl_c/fp32/strided_slice_fp32.h", "wrapper/base/strided_slice_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h index 5e2f2d6f..87fbdf7b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.h @@ -17,7 +17,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_BASE_CODER_H_ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/fp32/strided_slice_fp32.h" +#include "nnacl_c/fp32/strided_slice_fp32.h" namespace mindspore::lite::micro { class StridedSliceBaseCoder final : public OperatorCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc index 11b27a4e..b7e87358 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc @@ -81,7 +81,7 @@ int StridedSliceDynamicBaseCoder::Prepare(CoderContext *context) { int StridedSliceDynamicBaseCoder::DoCode(CoderContext *ctx) { Collect(ctx, { - "nnacl/fp32/strided_slice_fp32.h", + "nnacl_c/fp32/strided_slice_fp32.h", }, { "strided_slice_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h 
b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h index 1368c4e0..d7553b6e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h @@ -19,8 +19,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h" -#include "nnacl/strided_slice_parameter.h" -#include "nnacl/kernel/strided_slice.h" +#include "nnacl_c/strided_slice_parameter.h" +#include "nnacl_c/kernel/strided_slice.h" namespace mindspore::lite::micro { class StridedSliceDynamicBaseCoder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc index 3a6912df..e06094e0 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.cc @@ -60,7 +60,7 @@ int UnstackBaseCoder::Prepare(CoderContext *context) { int UnstackBaseCoder::DoCode(CoderContext *ctx) { Collect(ctx, { - "nnacl/base/unstack_base.h", + "nnacl_c/base/unstack_base.h", }, { "unstack_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h index e095be3c..84c2e938 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/base/unstack_base_coder.h @@ -17,8 +17,8 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_UNSTACK_BASE_CODER_H_ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/base/unstack_base.h" -#include "nnacl/op_base.h" +#include "nnacl_c/base/unstack_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::micro { class UnstackBaseCoder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc index 31fbde7a..2e92d877 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc @@ -23,8 +23,8 @@ #include "coder/opcoders/serializers/serializer.h" #include "coder/utils/common.h" #include "mindspore/ops/op_def/array_ops.h" -#include "nnacl/arithmetic_parameter.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/arithmetic_parameter.h" +#include "nnacl_c/int8/quantize.h" using mindspore::schema::PrimitiveType_AddFusion; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc index 431e7842..b7f0a556 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro::cmsis { int Conv2DBaseCoder::SetQuantArgs() { diff --git 
a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h index 7fb49373..c9bb04ee 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/base/conv2d_base_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::cmsis { class Conv2DBaseCoder : public micro::Conv2DBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h index c0409d7c..1ead129a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/conv2d_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/cmsis-nn/int8/conv2d_base_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::cmsis { class Conv2DInt8Coder final : public Conv2DBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h index 8e4bdec4..aa79e03b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/fullconnection_int8_coder.h @@ -21,7 +21,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/base/full_connection_base_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro::cmsis { class FullConnectionInt8Coder final : public FullConnectionBaseCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc index 5f0fe01b..4b6bb69e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/mul_int8_coder.cc @@ -17,7 +17,7 @@ #include "coder/opcoders/cmsis-nn/int8/mul_int8_coder.h" #include #include "coder/opcoders/serializers/serializer.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/file_collector.h" using mindspore::schema::PrimitiveType_MulFusion; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h index 05800968..daa157ed 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/cmsis-nn/int8/pooling_int8_coder.h @@ -21,7 +21,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/pooling_int8.h" +#include "nnacl_c/int8/pooling_int8.h" namespace mindspore::lite::micro::cmsis { class PoolingInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc 
b/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc index 2a40de01..f749d362 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/custom/custom_coder.cc @@ -23,7 +23,7 @@ #include "tools/converter/micro/coder/opcoders/op_coder_register.h" #include "tools/converter/micro/coder/opcoders/kernel_registry.h" #include "src/common/prim_util.h" -#include "nnacl/custom_parameter.h" +#include "nnacl_c/custom_parameter.h" using mindspore::schema::PrimitiveType_Custom; @@ -151,7 +151,7 @@ void CustomCoder::FreeTensors(Serializer *code, std::string array_name, size_t t } int CustomCoder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/custom_parameter.h", "nnacl/tensor_c.h", "src/registered_kernel.h"}, {}); + Collect(context, {"nnacl_c/custom_parameter.h", "nnacl_c/tensor_c.h", "src/registered_kernel.h"}, {}); Serializer code; MS_CHECK_RET_CODE(TransformTensors(&code, "inputs", input_tensors_), "Transform input tensors error!"); MS_CHECK_RET_CODE(TransformTensors(&code, "outputs", output_tensors_), "Transform output tensors error!"); diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc index b16b5e6d..d01d4c11 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc @@ -33,7 +33,7 @@ int ActivationDynamicFP16Coder::Prepare(CoderContext *const context) { int ActivationDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/activation_fp16.h", + "nnacl_c/fp16/activation_fp16.h", }, { "activation_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc index 0fdf0a7f..cbc12059 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_fp16_coder.cc @@ -35,7 +35,7 @@ int ActivationFP16Coder::DoCode(CoderContext *const context) { int count = input_tensor_->ElementsNum(); Collect(context, { - "nnacl/fp16/activation_fp16.h", + "nnacl_c/fp16/activation_fp16.h", }, { "activation_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc index 3228e7e2..001a9af1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc @@ -98,8 +98,8 @@ int ArithmeticDynamicFP16Coder::DoCode(CoderContext *const context) { NNaclFp32Serializer code; Collect(context, { - "nnacl/fp16/arithmetic_fp16.h", - "nnacl/base/broadcast_to.h", + "nnacl_c/fp16/arithmetic_fp16.h", + "nnacl_c/base/broadcast_to.h", }, { "arithmetic_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h index ca958d73..6451cad2 100644 --- 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h @@ -20,11 +20,11 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/base/cast_base.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/base/cast_base.h" +#include "nnacl_c/arithmetic_parameter.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h" -#include "nnacl/broadcast_to_parameter.h" +#include "nnacl_c/broadcast_to_parameter.h" namespace mindspore::lite::micro::nnacl { using mindspore::schema::PrimitiveType_AddFusion; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc index c31f7172..adaa95eb 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" -#include "nnacl/broadcast_to_parameter.h" +#include "nnacl_c/broadcast_to_parameter.h" #include "base/float16.h" namespace mindspore::lite::micro::nnacl { @@ -105,8 +105,8 @@ int ArithmeticFP16Coder::DoCode(CoderContext *const context) { NNaclFp32Serializer code; Collect(context, { - "nnacl/fp16/arithmetic_fp16.h", - "nnacl/base/broadcast_to.h", + "nnacl_c/fp16/arithmetic_fp16.h", + "nnacl_c/base/broadcast_to.h", }, { "arithmetic_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc index 4d7fe34a..02aaa731 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.cc @@ -59,7 +59,7 @@ int ArithmeticSelfFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/arithmetic_self_fp16.h", + "nnacl_c/fp16/arithmetic_self_fp16.h", }, { "arithmetic_self_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h index e95229f7..aa2db705 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_self_fp16_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/arithmetic_self_parameter.h" namespace mindspore::lite::micro::nnacl { using mindspore::schema::PrimitiveType_Abs; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc index 063ef0d8..fa40afcb 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc @@ -40,7 +40,7 @@ int 
ConcatDynamicFP16Coder::Prepare(CoderContext *const context) { int ConcatDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/concat_base.h", + "nnacl_c/base/concat_base.h", }, { "concat_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h index 6408403b..67825946 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" namespace mindspore::lite::micro::nnacl { class ConcatDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc index fd969963..52b11e20 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc @@ -42,7 +42,7 @@ int ConcatFP16Coder::ReSize() { int ConcatFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/concat_base.h", + "nnacl_c/base/concat_base.h", }, { "concat_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h index 6428ac6f..e3cafe32 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/nnacl/fp32/concat_fp32_coder.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" namespace mindspore::lite::micro::nnacl { class ConcatFP16Coder final : public ConcatFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc index 2c1e01af..7ea7893d 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc @@ -18,9 +18,9 @@ #include "src/common/version_manager.h" #include "src/common/tensor_util.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/winograd_utils.h" -#include "nnacl/base/conv_common_base.h" -#include "nnacl/infer/conv2d_infer.h" +#include "nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/infer/conv2d_infer.h" #include "coder/shape_info_container.h" #include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h" #include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h index 78dd3ebf..6852962f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h @@ -20,7 +20,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::nnacl { class ConvDelegateDynamicFP16Coder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc index 117ba90f..9d4cb9aa 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc @@ -18,9 +18,9 @@ #include "src/common/version_manager.h" #include "src/common/tensor_util.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/winograd_utils.h" -#include "nnacl/base/conv_common_base.h" -#include "nnacl/infer/conv2d_infer.h" +#include "nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/base/conv_common_base.h" +#include "nnacl_c/infer/conv2d_infer.h" #include "coder/opcoders/nnacl/fp16/convolution_fp16_coder.h" #include "coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.h" #include "coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h index 094b01b6..923bbd29 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.h @@ -20,7 +20,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::nnacl { class ConvDelegateFP16Coder : public ConvDelegateCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc index 690d0e1b..05654d88 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc @@ -90,8 +90,8 @@ void ConvolutionDepthwise3x3FP16Coder::CollectFilesForFunc(CoderContext *const c } Collect(context, { - "nnacl/fp16/conv_depthwise_fp16.h", - "nnacl/fp16/pack_fp16.h", + "nnacl_c/fp16/conv_depthwise_fp16.h", + "nnacl_c/fp16/pack_fp16.h", }, { "conv_depthwise_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h index 50c07f4e..fee11873 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "coder/opcoders/base/conv2d_base_coder.h" #include 
"coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc index 0140e2a4..72ccac96 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc @@ -65,9 +65,9 @@ void ConvolutionDepthwiseFP16Coder::CollectFilesForFunc(CoderContext *const cont } Collect(context, { - "nnacl/fp16/conv_depthwise_fp16.h", - "nnacl/fp16/pack_fp16.h", - "nnacl/fp16/activation_fp16.h", + "nnacl_c/fp16/conv_depthwise_fp16.h", + "nnacl_c/fp16/pack_fp16.h", + "nnacl_c/fp16/activation_fp16.h", }, { "conv_depthwise_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc index 10acc376..6a0aae33 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc @@ -95,9 +95,9 @@ void ConvolutionDepthwiseSWFP16Coder::CollectFilesForFunc(CoderContext *const co } Collect(context, { - "nnacl/fp32/conv_depthwise_fp32.h", - "nnacl/fp16/conv_depthwise_fp16.h", - "nnacl/fp16/pack_fp16.h", + "nnacl_c/fp32/conv_depthwise_fp32.h", + "nnacl_c/fp16/conv_depthwise_fp16.h", + "nnacl_c/fp16/pack_fp16.h", }, { "conv_depthwise_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc index 06b3e4ef..2c9544fc 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h" #include -#include "nnacl/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_utils.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" #include "coder/utils/coder_utils.h" @@ -231,11 +231,11 @@ void Convolution1x1DynamicFP16Coder::CollectFilesForFunc(CoderContext *const con } Collect(context, { - "nnacl/fp16/matmul_fp16.h", - "nnacl/conv_parameter.h", - "nnacl/op_base.h", - "nnacl/fp16/conv_fp16.h", - "nnacl/base/conv1x1_base.h", + "nnacl_c/fp16/matmul_fp16.h", + "nnacl_c/conv_parameter.h", + "nnacl_c/op_base.h", + "nnacl_c/fp16/conv_fp16.h", + "nnacl_c/base/conv1x1_base.h", }, { "common_func.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h index abc34ad2..a0558c9a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h @@ -19,8 +19,8 @@ #include #include -#include "nnacl/conv_parameter.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "coder/opcoders/op_coder.h" #include 
"coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc index 6c1aedea..b997b840 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc @@ -17,7 +17,7 @@ #include "coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h" #include #include -#include "nnacl/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_utils.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" @@ -148,11 +148,11 @@ void Convolution1x1FP16Coder::CollectFilesForFunc(CoderContext *const context) { } Collect(context, { - "nnacl/fp16/matmul_fp16.h", - "nnacl/conv_parameter.h", - "nnacl/op_base.h", - "nnacl/fp16/conv_fp16.h", - "nnacl/base/conv1x1_base.h", + "nnacl_c/fp16/matmul_fp16.h", + "nnacl_c/conv_parameter.h", + "nnacl_c/op_base.h", + "nnacl_c/fp16/conv_fp16.h", + "nnacl_c/base/conv1x1_base.h", "wrapper/base/micro_parameter.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h index f1e88619..6c5afb1f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "coder/opcoders/base/conv2d_base_coder.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "base/float16.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc index 19d1ab92..4f2782ea 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc @@ -17,7 +17,7 @@ #include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h" #include #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" -#include "nnacl/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_utils.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" @@ -131,10 +131,10 @@ void ConvolutionDynamicFP16Coder::CollectFilesForFunc(CoderContext *const contex }); Collect(context, { - "nnacl/fp16/matmul_fp16.h", - "nnacl/conv_parameter.h", - "nnacl/op_base.h", - "nnacl/fp16/conv_fp16.h", + "nnacl_c/fp16/matmul_fp16.h", + "nnacl_c/conv_parameter.h", + "nnacl_c/op_base.h", + "nnacl_c/fp16/conv_fp16.h", }, { "common_func.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h index 1ba47530..617e904b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "coder/opcoders/op_coder.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc index 43f0e00e..874dda1c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" -#include "nnacl/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_utils.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" @@ -101,10 +101,10 @@ void ConvolutionFP16Coder::CollectFilesForFunc(CoderContext *const context) { }); Collect(context, { - "nnacl/fp16/matmul_fp16.h", - "nnacl/conv_parameter.h", - "nnacl/op_base.h", - "nnacl/fp16/conv_fp16.h", + "nnacl_c/fp16/matmul_fp16.h", + "nnacl_c/conv_parameter.h", + "nnacl_c/op_base.h", + "nnacl_c/fp16/conv_fp16.h", }, { "common_func.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h index 2c92876f..42847830 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_fp16_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "coder/opcoders/nnacl/fp32/convolution_fp32_coder.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc index b03e484e..68d4b74b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h" #include -#include "nnacl/base/minimal_filtering_generator.h" +#include "nnacl_c/base/minimal_filtering_generator.h" #include "coder/opcoders/parallel.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" @@ -205,10 +205,10 @@ std::string ConvolutionWinogradFP16Coder::GetOutputTransFunc(int input_unit, int void ConvolutionWinogradFP16Coder::CollectFilesForFunc(CoderContext *const context) { Collect(context, - {"nnacl/fp16/conv_fp16.h", "nnacl/fp16/winograd_utils_fp16.h", - "nnacl/fp16/winograd_transform_fp16.h" - "nnacl/base/minimal_filtering_generator.h" - "nnacl/base/conv_common_base.h"}, + {"nnacl_c/fp16/conv_fp16.h", "nnacl_c/fp16/winograd_utils_fp16.h", + "nnacl_c/fp16/winograd_transform_fp16.h", + "nnacl_c/base/minimal_filtering_generator.h", + "nnacl_c/base/conv_common_base.h"}, {
"conv_fp16.c", "winograd_utils_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h index 824dbc5d..8f46a73a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.h @@ -21,7 +21,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::nnacl { typedef struct TransFuncFp16Str { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc index 5470b56a..19bdfad1 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc @@ -21,7 +21,7 @@ using mindspore::schema::PrimitiveType_Custom; namespace mindspore::lite::micro::nnacl { void CustomGruFP16Coder::InitNnaclFile(CoderContext *const context) { - Collect(context, {"nnacl/fp16/custom_gru_fp16.h"}, + Collect(context, {"nnacl_c/fp16/custom_gru_fp16.h"}, {"custom_gru_fp16.c", "pack_fp16.c", "matmul_fp16.c", "arithmetic_fp16.c", "activation_fp16.c"}); } diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h index ef38f1f0..1d16fce6 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h" -#include "nnacl/custom_gru_parameter.h" +#include "nnacl_c/custom_gru_parameter.h" namespace mindspore::lite::micro::nnacl { class CustomGruFP16Coder : public CustomGruFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc index d6be5f03..dfbdf43e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc @@ -99,16 +99,16 @@ void DeConvolutionFP16Coder::CollectFilesForFunc(CoderContext *const context) { } Collect(context, { - "nnacl/fp16/deconv_fp16.h", - "nnacl/fp16/pack_fp16.h", - "nnacl/fp16/matmul_fp16.h", - "nnacl/fp16/common_func_fp16.h", - "nnacl/base/minimal_filtering_generator.h", - "nnacl/conv_parameter.h", - "nnacl/common_func.h", - "nnacl/matmul_parameter.h", + "nnacl_c/fp16/deconv_fp16.h", + "nnacl_c/fp16/pack_fp16.h", + "nnacl_c/fp16/matmul_fp16.h", + "nnacl_c/fp16/common_func_fp16.h", + "nnacl_c/base/minimal_filtering_generator.h", + "nnacl_c/conv_parameter.h", + "nnacl_c/common_func.h", + "nnacl_c/matmul_parameter.h", "wrapper/base/micro_parameter.h", - "nnacl/op_base.h", + "nnacl_c/op_base.h", }, { "common_func.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h 
index 664fc24b..683c2b56 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc index 638a3e0b..9ac0bb00 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc @@ -33,7 +33,7 @@ int LayerNormFP16Coder::Prepare(CoderContext *const context) { int LayerNormFP16Coder::DoCode(CoderContext *const context) { NNaclFp32Serializer code; code.CodeStruct("layer_norm_compute_parm", compute_); - Collect(context, {"nnacl/fp16/layer_norm_fp16.h"}, {"layer_norm_fp16.c"}); + Collect(context, {"nnacl_c/fp16/layer_norm_fp16.h"}, {"layer_norm_fp16.c"}); if (output_tensors_.size() == C3NUM) { code.CodeFunction("LayerNormFp16", input_tensor_, input_tensors_.at(SECOND_INPUT), input_tensors_.at(THIRD_INPUT), output_tensor_, output_tensors_.at(SECOND_INPUT), output_tensors_.at(THIRD_INPUT), diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h index df025e3c..7be90eb6 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h" -#include "nnacl/layer_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" namespace mindspore::lite::micro::nnacl { class LayerNormFP16Coder final : public LayerNormFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc index f2a1eaa3..213144f2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc @@ -198,8 +198,8 @@ int LstmFP16Coder::Prepare(CoderContext *const context) { int LstmFP16Coder::DoCode(CoderContext *context) { Collect(context, { - "nnacl/lstm_parameter.h", - "nnacl/fp16/lstm_fp16.h", + "nnacl_c/lstm_parameter.h", + "nnacl_c/fp16/lstm_fp16.h", }, { "lstm_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h index fbaa3bd0..8ab283b9 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/nnacl/fp32/lstm_fp32_coder.h" -#include "nnacl/lstm_parameter.h" +#include "nnacl_c/lstm_parameter.h" namespace mindspore::lite::micro::nnacl { class LstmFP16Coder final : public LstmFP32Coder { diff --git 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc index 66b6db4b..86dda658 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc @@ -46,7 +46,7 @@ int LstmMindirDynamicFP16Coder::Prepare(CoderContext *const context) { } int LstmMindirDynamicFP16Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/lstm_parameter.h", "nnacl/fp16/lstm_fp16.h"}, + Collect(context, {"nnacl_c/lstm_parameter.h", "nnacl_c/fp16/lstm_fp16.h"}, {"lstm_fp16.c", "activation_fp16.c", "arithmetic_fp16.c", "matmul_fp16.c", "pack_fp16.c"}, {"MatmulBaseFp16Neon.S"}); auto ret = InitInputWeightBias(context); diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h index a348b917..1526547c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/lstm_parameter.h" +#include "nnacl_c/lstm_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h" #include "coder/opcoders/op_coder.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc index f1b65488..a82db8fb 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc @@ -208,8 +208,8 @@ int MatMulDynamicFP16BaseCoder::ComputeMatrixAWorkspace() { int MatMulDynamicFP16BaseCoder::CollectFilesForTarget(CoderContext *const context) { Collect(context, { - "nnacl/fp16/pack_fp16.h", - "nnacl/fp16/matmul_fp16.h", + "nnacl_c/fp16/pack_fp16.h", + "nnacl_c/fp16/matmul_fp16.h", }, { "pack_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h index 250fb96b..29de2e3c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h @@ -21,7 +21,7 @@ #include #include "tools/converter/micro/coder/opcoders/op_coder.h" #include "tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/matmul_dynamic_parameter.h" #include "base/float16.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h index 768143b2..6f1a38c0 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h +++ 
b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class MatMulDynamicFP16Coder final : public MatMulDynamicFP16BaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc index 1bf4e8ca..e188b928 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc @@ -227,8 +227,8 @@ int MatMulFP16BaseCoder::Prepare(CoderContext *const context) { int MatMulFP16BaseCoder::CollectFilesForTarget(CoderContext *const context) { Collect(context, { - "nnacl/fp16/pack_fp16.h", - "nnacl/fp16/matmul_fp16.h", + "nnacl_c/fp16/pack_fp16.h", + "nnacl_c/fp16/matmul_fp16.h", }, { "pack_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h index b56a8c12..47398fbe 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h @@ -21,7 +21,7 @@ #include #include "coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class MatMulFP16BaseCoder : public MatMulFP32BaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h index c5ea36cd..dd381ced 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore::lite::micro::nnacl { class MatMulFP16Coder final : public MatMulFP16BaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc index 8fd99da3..49abf35b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc @@ -65,7 +65,7 @@ int PoolingDynamicFP16Coder::Prepare(CoderContext *const context) { int PoolingDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/pooling_fp16.h", + "nnacl_c/fp16/pooling_fp16.h", }, { "pooling_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h index d36f9356..79b67ccc 100644 --- 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h @@ -20,8 +20,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/kernel/pooling.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/kernel/pooling.h" namespace mindspore::lite::micro::nnacl { class PoolingDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc index b043a881..513f1ae3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_fp16_coder.cc @@ -56,8 +56,8 @@ int PoolingFP16Coder::DoCode(CoderContext *const context) { float maxf = FLT16_MAX; Collect(context, { - "nnacl/fp16/pooling_fp16.h", - "nnacl/kernel/pooling.h", + "nnacl_c/fp16/pooling_fp16.h", + "nnacl_c/kernel/pooling.h", }, { "pooling_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc index 5d2bf54c..31b5f501 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc @@ -34,7 +34,7 @@ int ReduceFP16Coder::Prepare(CoderContext *const context) { int ReduceFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/reduce_fp16.h", + "nnacl_c/fp16/reduce_fp16.h", }, { "reduce_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc index 18c40b8b..c19ef95f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc @@ -22,7 +22,7 @@ #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" #include "coder/utils/common.h" -#include "nnacl/fp32/resize_fp32.h" +#include "nnacl_c/fp32/resize_fp32.h" #include "base/float16.h" using mindspore::schema::PrimitiveType_Resize; @@ -33,8 +33,8 @@ int ResizeFP16Coder::DataTypeLen() { return sizeof(uint16_t); } int ResizeFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/resize_fp16.h", - "nnacl/fp32/resize_fp32.h", + "nnacl_c/fp16/resize_fp16.h", + "nnacl_c/fp32/resize_fp32.h", }, { "resize_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h index 769bb62c..59ae6714 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/resize_fp16_coder.h @@ -23,7 +23,7 @@ #include #include "include/errorcode.h" #include "src/executor/kernel_exec.h" -#include "nnacl/base/cast_base.h" +#include "nnacl_c/base/cast_base.h" namespace mindspore::lite::micro::nnacl { class ResizeFP16Coder : public ResizeFP32Coder { diff --git 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc index 0d6790e5..05c905f7 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc @@ -50,8 +50,8 @@ int ScaleDynamicFP16Coder::DoCode(CoderContext *const context) { // init struct ScaleParameters Collect(context, { - "nnacl/kernel/scale.h", - "nnacl/fp16/scale_fp16.h", + "nnacl_c/kernel/scale.h", + "nnacl_c/fp16/scale_fp16.h", }, { "scale_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h index e64286a8..723bf08e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h @@ -20,8 +20,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h" -#include "nnacl/kernel/scale.h" -#include "nnacl/scale_parameter.h" +#include "nnacl_c/kernel/scale.h" +#include "nnacl_c/scale_parameter.h" namespace mindspore::lite::micro::nnacl { class ScaleDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc index d03aeabd..ee6c0b7c 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.cc @@ -42,9 +42,9 @@ int ScaleFP16Coder::DoCode(CoderContext *const context) { // init struct ScaleParameters Collect(context, { - "nnacl/scale_parameter.h", - "nnacl/kernel/scale.h", - "nnacl/fp16/scale_fp16.h", + "nnacl_c/scale_parameter.h", + "nnacl_c/kernel/scale.h", + "nnacl_c/fp16/scale_fp16.h", }, { "scale_fp32_wrapper.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h index 5032cf51..b4da8a24 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_fp16_coder.h @@ -20,8 +20,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/fp32/scale_fp32_coder.h" -#include "nnacl/scale_parameter.h" -#include "nnacl/kernel/scale.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/kernel/scale.h" namespace mindspore::lite::micro::nnacl { class ScaleFP16Coder final : public ScaleFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc index a4ffaa0d..e4861206 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc @@ -44,7 +44,7 @@ int SliceDynamicFP16Coder::Prepare(CoderContext *const context) { int SliceDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - 
"nnacl/base/slice_base.h", + "nnacl_c/base/slice_base.h", }, { "slice_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h index 35b0bde0..6defebcf 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h @@ -21,8 +21,8 @@ #include #include "tools/converter/micro/coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h" -#include "nnacl/kernel/slice.h" -#include "nnacl/op_base.h" +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::micro::nnacl { class SliceDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc index d75c9f67..36517122 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.cc @@ -24,7 +24,7 @@ namespace mindspore::lite::micro::nnacl { int SliceFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/slice_base.h", + "nnacl_c/base/slice_base.h", }, { "slice_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h index d6f503ec..ec5fb0a0 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_fp16_coder.h @@ -20,7 +20,7 @@ #include #include "tools/converter/micro/coder/opcoders/op_coder.h" #include "tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h" -#include "nnacl/kernel/slice.h" +#include "nnacl_c/kernel/slice.h" namespace mindspore::lite::micro::nnacl { class SliceFP16Coder final : public SliceFP32Coder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc index 9c08e9f3..27046968 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc @@ -44,8 +44,8 @@ int SoftmaxDynamicFP16Coder::Prepare(CoderContext *const context) { int SoftmaxDynamicFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/softmax_fp16.h", - "nnacl/fp16/log_softmax_fp16.h", + "nnacl_c/fp16/softmax_fp16.h", + "nnacl_c/fp16/log_softmax_fp16.h", }, { "softmax_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h index 1063969b..4041db0f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h @@ -21,8 +21,8 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/opcoders/nnacl/dynamic_parameter/softmax_dynamic_parameter.h" -#include 
"nnacl/softmax_parameter.h" -#include "nnacl/kernel/softmax.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/kernel/softmax.h" namespace mindspore::lite::micro::nnacl { class SoftmaxDynamicFP16Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc index ceea05ed..4bc8e690 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_fp16_coder.cc @@ -40,8 +40,8 @@ int SoftMaxFP16Coder::Prepare(CoderContext *const context) { int SoftMaxFP16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp16/softmax_fp16.h", - "nnacl/fp16/log_softmax_fp16.h", + "nnacl_c/fp16/softmax_fp16.h", + "nnacl_c/fp16/log_softmax_fp16.h", }, { "softmax_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc index 59c8d8b8..7ce12f96 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc @@ -40,9 +40,9 @@ int TransposeDynamicFp16Coder::Prepare(CoderContext *const context) { int TransposeDynamicFp16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/transpose_parameter.h", - "nnacl/errorcode.h", - "nnacl/fp16/transpose_fp16.h", + "nnacl_c/transpose_parameter.h", + "nnacl_c/errorcode.h", + "nnacl_c/fp16/transpose_fp16.h", }, { "transpose_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h index b31f1022..62594117 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc index 3d826cae..1ada88d6 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc @@ -81,9 +81,9 @@ int TransposeFp16Coder::ResetStatus() { int TransposeFp16Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/transpose_parameter.h", - "nnacl/errorcode.h", - "nnacl/fp16/transpose_fp16.h", + "nnacl_c/transpose_parameter.h", + "nnacl_c/errorcode.h", + "nnacl_c/fp16/transpose_fp16.h", }, { "transpose_fp16.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h index ce99f558..91d68a7a 100644 --- 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_fp16_coder.h @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/transpose_fp32_coder.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" namespace mindspore::lite::micro::nnacl { class TransposeFp16Coder final : public TransposeFp32Coder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc index edc442e9..bae2a839 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "coder/opcoders/nnacl/fp32/activation_fp32_coder.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" @@ -29,7 +29,7 @@ int ActivationFP32Coder::DoCode(CoderContext *const context) { Collect(context, { "wrapper/fp32/activation_fp32_wrapper.h", - "nnacl/fp32/activation_fp32.h", + "nnacl_c/fp32/activation_fp32.h", }, { "activation_fp32_wrapper.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc index d2394514..04465a07 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc @@ -31,7 +31,7 @@ int AddNFP32Coder::DoCode(CoderContext *const context) { // Get Tensor Pointer Collect(context, { - "nnacl/fp32/add_fp32.h", + "nnacl_c/fp32/add_fp32.h", }, { "add_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h index f60a80fe..b71d0818 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/affine_parameter.h" +#include "nnacl_c/affine_parameter.h" #include "tools/converter/micro/coder/wrapper/base/affine_wrapper.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc index 9ec429e8..c94b1002 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h" #include #include "coder/opcoders/file_collector.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "coder/opcoders/parallel.h" #include "coder/log.h" @@ -208,7 +208,7 @@ int ArithmeticFP32Coder::ConstTensorBroadCast(CoderContext *const context) { } FreeConstTileBuff(); NNaclFp32Serializer init_code; - 
Collect(context, {"wrapper/fp32/arithmetic_fp32_wrapper.h", "nnacl/fp32/arithmetic_fp32.h"}, + Collect(context, {"wrapper/fp32/arithmetic_fp32_wrapper.h", "nnacl_c/fp32/arithmetic_fp32.h"}, {"arithmetic_fp32_wrapper.c", "arithmetic_fp32.c"}); if (input_tensor_->IsConst() && arithmetic_parameter_->in_elements_num0_ != arithmetic_parameter_->out_elements_num_) { @@ -286,7 +286,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { if (arithmetic_opt_run_ == "ElementOptSub" || arithmetic_run_ == "ElementSub") { Collect(context, { - "nnacl/fp32/sub_fp32.h", + "nnacl_c/fp32/sub_fp32.h", }, { "sub_fp32.c", @@ -294,7 +294,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else if (arithmetic_opt_run_ == "ElementOptAdd" || arithmetic_run_ == "ElementAdd") { Collect(context, { - "nnacl/fp32/add_fp32.h", + "nnacl_c/fp32/add_fp32.h", }, { "add_fp32.c", @@ -304,7 +304,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else if (arithmetic_opt_run_ == "ElementOptMul" || arithmetic_run_ == "ElementMul") { Collect(context, { - "nnacl/fp32/mul_fp32.h", + "nnacl_c/fp32/mul_fp32.h", }, { "mul_fp32.c", @@ -312,7 +312,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else if (arithmetic_run_ == "ElementAddRelu") { Collect(context, { - "nnacl/fp32/add_fp32.h", + "nnacl_c/fp32/add_fp32.h", }, { "add_fp32.c", @@ -321,7 +321,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { arithmetic_run_ == "ElementDiv") { Collect(context, { - "nnacl/fp32/div_fp32.h", + "nnacl_c/fp32/div_fp32.h", }, { "div_fp32.c", @@ -329,7 +329,7 @@ void ArithmeticFP32Coder::CollectFilesForFunc(CoderContext *const context) { } else { Collect(context, { - "nnacl/fp32/arithmetic_fp32.h", + "nnacl_c/fp32/arithmetic_fp32.h", }, { "arithmetic_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h index 169ed457..4f35e507 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_fp32_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "wrapper/fp32/arithmetic_fp32_wrapper.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc index 583e1d0d..c423e263 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.cc @@ -17,7 +17,7 @@ #include "coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h" #include #include -#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" @@ -68,7 +68,7 @@ int ArithmeticSelfFP32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/fp32/arithmetic_self_fp32.h", + 
"nnacl_c/fp32/arithmetic_self_fp32.h", }, { "arithmetic_self_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h index 64f57af2..a7c07707 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/arithmetic_self_fp32_coder.h @@ -20,8 +20,8 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/fp32/arithmetic_self_fp32.h" -#include "nnacl/arithmetic_self_parameter.h" +#include "nnacl_c/fp32/arithmetic_self_fp32.h" +#include "nnacl_c/arithmetic_self_parameter.h" namespace mindspore::lite::micro::nnacl { using mindspore::schema::PrimitiveType_Abs; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h index 48f60254..f0f878a7 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/assign_add_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/base/tile_base.h" +#include "nnacl_c/base/tile_base.h" namespace mindspore::lite::micro::nnacl { class AssignAddFP32Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc index 0df6c6c9..61f6b395 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.cc @@ -16,8 +16,8 @@ #include "coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h" #include #include -#include "nnacl/fp32/batchnorm_fp32.h" -#include "nnacl/op_base.h" +#include "nnacl_c/fp32/batchnorm_fp32.h" +#include "nnacl_c/op_base.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/parallel.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" @@ -58,11 +58,11 @@ int BatchnormFP32Coder::DoCode(CoderContext *const context) { MS_CHECK_PTR(var_tensor); Collect(context, { - "nnacl/fp32/batchnorm.h", - "nnacl/kernel/batch_norm.h", + "nnacl_c/fp32/batchnorm.h", + "nnacl_c/kernel/batch_norm.h", }, { - "nnacl/fp32/batchnorm.c", + "nnacl_c/fp32/batchnorm.c", }); NNaclFp32Serializer code; code.CodeStruct("bn_struct", batchnorm_struct_); diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h index 1f76796a..e7e17606 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/batchnorm_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/kernel/batch_norm.h" +#include "nnacl_c/kernel/batch_norm.h" namespace mindspore::lite::micro::nnacl { class BatchnormFP32Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc index cf36a3d5..d1b48dd6 100644 --- 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.cc @@ -15,6 +15,7 @@ */ #include "coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h" +#include #include "coder/opcoders/file_collector.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" @@ -35,12 +36,12 @@ int BiasAddFP32Coder::DoCode(CoderContext *ctx) { std::string bias_str = allocator_->GetRuntimeAddr(input_tensors_.at(kWeightIndex), true); Collect(ctx, { - "nnacl/arithmetic_parameter.h", - "nnacl/nnacl_utils.h", - "nnacl/nnacl_common.h", - "nnacl/base/arithmetic_base.h", - "nnacl/fp32/add_fp32.h", - "nnacl/fp32/arithmetic_fp32.h", + "nnacl_c/arithmetic_parameter.h", + "nnacl_c/nnacl_utils.h", + "nnacl_c/nnacl_common.h", + "nnacl_c/base/arithmetic_base.h", + "nnacl_c/fp32/add_fp32.h", + "nnacl_c/fp32/arithmetic_fp32.h", }, { "arithmetic_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h index 62abec08..25c983db 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/biasadd_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/arithmetic_parameter.h" +#include "nnacl_c/arithmetic_parameter.h" namespace mindspore::lite::micro::nnacl { class BiasAddFP32Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc index 6419e11d..da17059d 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.cc @@ -37,7 +37,7 @@ int ConcatFP32Coder::ReSize() { int ConcatFP32Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/base/concat_base.h", + "nnacl_c/base/concat_base.h", }, { "concat_base.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h index 6f3f5c71..2cda19ff 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/concat_fp32_coder.h @@ -19,7 +19,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/concat_parameter.h" +#include "nnacl_c/concat_parameter.h" namespace mindspore::lite::micro::nnacl { class ConcatFP32Coder : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc index 3aece46e..b4608f66 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.cc @@ -17,8 +17,8 @@ #include "coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h" #include "src/common/version_manager.h" #include "src/common/ops/populate/populate_register.h" -#include "nnacl/fp32/winograd_utils.h" -#include "nnacl/base/conv_common_base.h" +#include 
"nnacl_c/fp32/winograd_utils.h" +#include "nnacl_c/base/conv_common_base.h" #include "coder/opcoders/nnacl/fp32/convolution_fp32_coder.h" #include "coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h" #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h index 0cd3adb5..08678bcb 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/conv2d_delegate_fp32_coder.h @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" namespace mindspore::lite::micro::nnacl { class ConvDelegateCoder : public OperatorCoder { public: diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc index 537fa14e..492af537 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc @@ -84,7 +84,7 @@ void ConvolutionDepthwiseFP32Coder::InitCodeOnline(CoderContext *const context) } Collect(context, { - "nnacl/fp32/pack_fp32.h", + "nnacl_c/fp32/pack_fp32.h", }, {"pack_fp32.c"}); NNaclFp32Serializer init_code; @@ -117,7 +117,7 @@ void ConvolutionDepthwiseFP32Coder::CollectFilesForFunc(CoderContext *const cont } Collect(context, { - "nnacl/fp32/conv_depthwise_fp32.h", + "nnacl_c/fp32/conv_depthwise_fp32.h", }, { "conv_depthwise_fp32.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc index cbdf2e71..6008a5f8 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc @@ -19,7 +19,7 @@ #include #include #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" -#include "nnacl/fp32/winograd_utils.h" +#include "nnacl_c/fp32/winograd_utils.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" @@ -145,10 +145,10 @@ void ConvolutionFP32Coder::CollectFilesForFunc(CoderContext *const context) { } Collect(context, { - "nnacl/fp32/conv_common_fp32.h", - "nnacl/fp32/matmul_fp32.h", - "nnacl/conv_parameter.h", - "nnacl/op_base.h", + "nnacl_c/fp32/conv_common_fp32.h", + "nnacl_c/fp32/matmul_fp32.h", + "nnacl_c/conv_parameter.h", + "nnacl_c/op_base.h", }, { "common_func.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h index cf6ad614..94fbe059 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_fp32_coder.h @@ -19,7 +19,7 @@ #include #include -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "coder/opcoders/base/conv2d_base_coder.h" #include 
"coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc index c15d3101..39938f55 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" #include -#include "nnacl/base/minimal_filtering_generator.h" +#include "nnacl_c/base/minimal_filtering_generator.h" #include "coder/log.h" #include "coder/opcoders/parallel.h" #include "coder/opcoders/file_collector.h" @@ -226,10 +226,10 @@ void ConvolutionWinogradFP32Coder::InitCodeOnline(CoderContext *const context) { } Collect(context, { - "nnacl/base/minimal_filtering_generator.h", - "nnacl/fp32/pack_fp32.h", + "nnacl_c/base/minimal_filtering_generator.h", + "nnacl_c/fp32/pack_fp32.h", }, - {"minimal_filtering_generator.c", "nnacl/fp32/pack_fp32.h"}); + {"minimal_filtering_generator.c", "nnacl_c/fp32/pack_fp32.h"}); NNaclFp32Serializer init_code; init_code.CodeBufferOffsetExpression(trans_weight_, context->weight_name(), context->weight_offset_name(), context->weight_size_name(), trans_weight_size_); @@ -279,8 +279,8 @@ void ConvolutionWinogradFP32Coder::CollectFilesForFunc(CoderContext *const conte } Collect(context, { - "nnacl/fp32/conv_winograd_fp32.h", - "nnacl/common_func.h", + "nnacl_c/fp32/conv_winograd_fp32.h", + "nnacl_c/common_func.h", }, { "common_func.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h index c5c3a534..a3d6489f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h @@ -21,7 +21,7 @@ #include #include #include "coder/opcoders/base/conv2d_base_coder.h" -#include "nnacl/conv_parameter.h" +#include "nnacl_c/conv_parameter.h" #include "wrapper/fp32/conv_winograd_fp32_wrapper.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc index ecbc6701..e5ba2f5a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc @@ -16,7 +16,7 @@ #include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" #include "coder/opcoders/file_collector.h" -#include "nnacl/custom_gru_parameter.h" +#include "nnacl_c/custom_gru_parameter.h" using mindspore::schema::PrimitiveType_Custom; @@ -127,7 +127,7 @@ int CustomGruFP32Coder::InitWeightAndBias() { } void CustomGruFP32Coder::InitNnaclFile(CoderContext *const context) { - Collect(context, {"nnacl/fp32/custom_gru_fp32.h"}, + Collect(context, {"nnacl_c/fp32/custom_gru_fp32.h"}, {"custom_gru_fp32.c", "pack_fp32.c", "matmul_fp32.c", "arithmetic_fp32.c", "activation_fp32.c"}); } diff --git 
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
index f7ccbf31..ffb07140 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/custom_gru_parameter.h"
+#include "nnacl_c/custom_gru_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc
index 83bf34d4..181bef9f 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h"
-#include "nnacl/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/winograd_utils.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -148,15 +148,15 @@ void DeConvolutionFP32Coder::CollectFilesForFunc(CoderContext *const context) {
   Collect(context,
           {
             "wrapper/fp32/deconvolution_fp32_wrapper.h",
-            "nnacl/fp32/conv_common_fp32.h",
-            "nnacl/pack.h",
-            "nnacl/fp32/common_func_fp32.h",
-            "nnacl/base/minimal_filtering_generator.h",
-            "nnacl/fp32/matmul_fp32.h",
-            "nnacl/conv_parameter.h",
-            "nnacl/matmul_parameter.h",
+            "nnacl_c/fp32/conv_common_fp32.h",
+            "nnacl_c/pack.h",
+            "nnacl_c/fp32/common_func_fp32.h",
+            "nnacl_c/base/minimal_filtering_generator.h",
+            "nnacl_c/fp32/matmul_fp32.h",
+            "nnacl_c/conv_parameter.h",
+            "nnacl_c/matmul_parameter.h",
             "wrapper/base/micro_parameter.h",
-            "nnacl/op_base.h",
+            "nnacl_c/op_base.h",
           },
           {
             "deconvolution_fp32_wrapper.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h
index b6901eb0..a01bc376 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.h
@@ -19,11 +19,11 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/base/conv2d_base_coder.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
-#include "nnacl/fp32/deconv_fp32.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/deconv_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"

 namespace mindspore::lite::micro::nnacl {
 class DeConvolutionFP32Coder : public Conv2DBaseCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc
index 0cbc7ea1..3fb054b5 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc
@@ -44,7 +44,7 @@ int ExpFP32Coder::Prepare(CoderContext *context) {
 int ExpFP32Coder::DoCode(CoderContext *ctx) {
   Collect(ctx,
           {
-            "nnacl/fp32/exp_fp32.h",
+            "nnacl_c/fp32/exp_fp32.h",
           },
           {
             "exp_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h
index 2be5cc5b..20f4628f 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/exp_fp32_coder.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_EXP_FP32_CODER_H_
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/fp32/exp_fp32.h"
+#include "nnacl_c/fp32/exp_fp32.h"
 namespace mindspore::lite::micro::nnacl {
 class ExpFP32Coder final : public OperatorCoder {
  public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc
index 32d89de9..beb21734 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc
@@ -36,7 +36,7 @@ int FillFP32Coder::Prepare(CoderContext *context) {
 int FillFP32Coder::DoCode(CoderContext *ctx) {
   Collect(ctx,
           {
-            "nnacl/kernel/fill.h",
+            "nnacl_c/kernel/fill.h",
           },
           {
             "fill.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h
index ccd47e5c..c9f6cf23 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/fill_fp32_coder.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_FILL_FP32_CODER_H_
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/kernel/fill.h"
+#include "nnacl_c/kernel/fill.h"
 namespace mindspore::lite::micro::nnacl {
 class FillFP32Coder final : public OperatorCoder {
  public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc
index 3cea0e35..a9066d29 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h"
 #include
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/gather_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/utils/coder_utils.h"
@@ -44,7 +44,7 @@ int GatherDynamicFP32Coder::Prepare(CoderContext *const context) {
 int GatherDynamicFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/base/gather_base.h",
+            "nnacl_c/base/gather_base.h",
           },
           {
             "gather_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
index 1208badd..0040822c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32/gather_fp32_coder.h"
 #include
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/gather_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/file_collector.h"
@@ -55,7 +55,7 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
   // generate code .h .c
   Collect(context,
           {
-            "nnacl/base/gather_base.h",
+            "nnacl_c/base/gather_base.h",
           },
           {
             "gather_base.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
index 6bf7ae6a..062e247d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/base/tile_base.h"
+#include "nnacl_c/base/tile_base.h"

 namespace mindspore::lite::micro::nnacl {
 class GatherFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
index b5415861..a66d3c45 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
@@ -16,8 +16,8 @@
 #include "coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.h"
 #include
 #include
-#include "nnacl/fp32/group_norm_fp32.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/fp32/group_norm_fp32.h"
+#include "nnacl_c/op_base.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
@@ -68,7 +68,7 @@ int GroupNormFP32Coder::DoCode(CoderContext *const context) {
   MS_CHECK_PTR(offset_tensor);
   Collect(context,
           {
-            "nnacl/fp32/group_norm_fp32.h",
+            "nnacl_c/fp32/group_norm_fp32.h",
           },
           {
             "group_norm_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc
index 367ab77a..cb93a003 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.cc
@@ -40,7 +40,7 @@ int InstanceNormFP32Coder::Prepare(CoderContext *const context) {
 int InstanceNormFP32Coder::DoCode(CoderContext *const context) {
   NNaclFp32Serializer code;
   code.CodeStruct("instance_norm_param", *param_);
-  Collect(context, {"nnacl/fp32/pack_fp32.h", "nnacl/fp32/instance_norm_fp32.h"},
+  Collect(context, {"nnacl_c/fp32/pack_fp32.h", "nnacl_c/fp32/instance_norm_fp32.h"},
           {"pack_fp32.c", "instance_norm_fp32.c"});
   if (input_tensors_[0]->format() == NHWC) {
     code.CodeFunction("PackNHWCToNC4HW4NotAlignedFp32", input_tensor_, tmp_src_data_, param_->batch_,
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h
index 2fd42c4f..ad4fa813 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/instance_norm_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/instance_norm_parameter.h"
+#include "nnacl_c/instance_norm_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class InstanceNormFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc
index e5d6b10e..f5f135a4 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc
@@ -68,7 +68,7 @@ int LayerNormFP32Coder::Prepare(CoderContext *const context) {
 int LayerNormFP32Coder::DoCode(CoderContext *const context) {
   NNaclFp32Serializer code;
   code.CodeStruct("layer_norm_compute_parm", compute_);
-  Collect(context, {"nnacl/fp32/layer_norm_fp32.h"}, {"layer_norm_fp32.c"});
+  Collect(context, {"nnacl_c/fp32/layer_norm_fp32.h"}, {"layer_norm_fp32.c"});
   if (output_tensors_.size() == kOutputNum) {
     code.CodeFunction("LayerNorm", input_tensor_, input_tensors_.at(SECOND_INPUT), input_tensors_.at(THIRD_INPUT),
                       output_tensor_, output_tensors_.at(SECOND_INPUT), output_tensors_.at(THIRD_INPUT),
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h
index 4842d9cf..bde27e13 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.h
@@ -19,8 +19,8 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/layer_norm_parameter.h"
-#include "nnacl/kernel/layer_norm.h"
+#include "nnacl_c/layer_norm_parameter.h"
+#include "nnacl_c/kernel/layer_norm.h"

 namespace mindspore::lite::micro::nnacl {
 class LayerNormFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
index 30f95332..ab42dac9 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc
@@ -182,8 +182,8 @@ int LstmFP32Coder::Prepare(CoderContext *const context) {
 int LstmFP32Coder::DoCode(CoderContext *context) {
   Collect(context,
           {
-            "nnacl/lstm_parameter.h",
-            "nnacl/fp32/lstm_fp32.h",
+            "nnacl_c/lstm_parameter.h",
+            "nnacl_c/fp32/lstm_fp32.h",
           },
           {
             "lstm_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h
index b54bf53a..eed3c3c1 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/lstm_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/lstm_parameter.h"
+#include "nnacl_c/lstm_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class LstmFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
index 84438cc8..0193dfa7 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
@@ -21,7 +21,7 @@
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 #include "wrapper/fp32/matmul_fp32_wrapper.h"
 #include "coder/opcoders/nnacl/dequant/de_quant.h"

@@ -160,8 +160,8 @@ int MatMulFP32BaseCoder::Prepare(CoderContext *const context) { return RET_OK; }
 int MatMulFP32BaseCoder::CollectFilesForTarget(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/fp32/pack_fp32.h",
-            "nnacl/fp32/matmul_fp32.h",
+            "nnacl_c/fp32/pack_fp32.h",
+            "nnacl_c/fp32/matmul_fp32.h",
             "wrapper/fp32/matmul_fp32_wrapper.h",
           },
           {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
index 6e92508a..11a21413 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "wrapper/base/micro_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h
index 046a2a1a..e05bae8c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class MatMulFP32Coder final : public MatMulFP32BaseCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc
index 32909be1..5646ad25 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc
@@ -29,7 +29,7 @@ int OnesLikeFP32Coder::Prepare(CoderContext *const context) { return RET_OK; }
 int OnesLikeFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/kernel/ones_like.h",
+            "nnacl_c/kernel/ones_like.h",
           },
           {
             "ones_like.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc
index e51b055f..be7212f1 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc
@@ -81,8 +81,8 @@ int PadFP32Coder::ExtendPaddings(int *paddings, int length, const int *ori_paddi
 int PadFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/fp32/pad_fp32.h",
-            "nnacl/pad_parameter.h",
+            "nnacl_c/fp32/pad_fp32.h",
+            "nnacl_c/pad_parameter.h",
           },
           {
             "pad_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h
index 3cf3ff31..30446a87 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pad_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/fp32/pad_fp32.h"
+#include "nnacl_c/fp32/pad_fp32.h"

 namespace mindspore::lite::micro::nnacl {
 class PadFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc
index b02dc336..25bdc495 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32/pooling_fp32_coder.h"
 #include
 #include
-#include "nnacl/fp32/pooling_fp32.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -49,8 +49,8 @@ int PoolingFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
             "wrapper/fp32/pooling_fp32_wrapper.h",
-            "nnacl/kernel/pooling.h",
-            "nnacl/fp32/pooling_fp32.h",
+            "nnacl_c/kernel/pooling.h",
+            "nnacl_c/fp32/pooling_fp32.h",
           },
           {
             "pooling_fp32_wrapper.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h
index 37bebc00..79f8844d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/pooling_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/kernel/pooling.h"
+#include "nnacl_c/kernel/pooling.h"

 namespace mindspore::lite::micro::nnacl {
 class PoolingFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc
index 462d1a45..edc05ea0 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.cc
@@ -51,8 +51,8 @@ int PowerFP32Coder::DoCode(CoderContext *const context) {
   // generate code .h .c
   Collect(context,
           {
-            "nnacl/pow_parameter.h",
-            "nnacl/fp32/power_fp32.h",
+            "nnacl_c/pow_parameter.h",
+            "nnacl_c/fp32/power_fp32.h",
           },
           {
             "power_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h
index 84fb3a39..e99ff594 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/power_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/pow_parameter.h"
+#include "nnacl_c/pow_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class PowerFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc
index defb669a..68d7430d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc
@@ -15,8 +15,8 @@
  */
 #include "coder/opcoders/nnacl/fp32/prelu_fp32_coder.h"
 #include
-#include "nnacl/fp32/prelu_fp32.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/fp32/prelu_fp32.h"
+#include "nnacl_c/op_base.h"
 #include "coder/allocator/allocator.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
@@ -29,7 +29,7 @@ int PReluFP32Coder::DoCode(CoderContext *const context) {
   int count = input_tensor_->ElementsNum();
   Collect(context,
           {
-            "nnacl/fp32/prelu_fp32.h",
+            "nnacl_c/fp32/prelu_fp32.h",
           },
           {
             "prelu_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc
index 2970309a..ed6127c2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/reduce_fp32_coder.cc
@@ -50,7 +50,7 @@ int ReduceFP32Coder::ReSize() {
 int ReduceFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/fp32/reduce_fp32.h",
+            "nnacl_c/fp32/reduce_fp32.h",
           },
           {
             "reduce_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc
index d84d0c60..fdffada8 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.cc
@@ -159,7 +159,7 @@ int ResizeFP32Coder::ResizePrepare() {
 int ResizeFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/fp32/resize_fp32.h",
+            "nnacl_c/fp32/resize_fp32.h",
           },
           {
             "resize_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h
index 6654df2c..9375ed99 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/resize_fp32_coder.h
@@ -22,7 +22,7 @@
 #include
 #include
 #include "include/errorcode.h"
-#include "nnacl/fp32/resize_fp32.h"
+#include "nnacl_c/fp32/resize_fp32.h"
 #include "src/executor/kernel_exec.h"
 #include "src/litert/kernel/cpu/fp32/resize_fp32.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc
index d4b3ca14..e88d4912 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.cc
@@ -93,9 +93,9 @@ int ScaleFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
             "wrapper/fp32/scale_fp32_wrapper.h",
-            "nnacl/scale_parameter.h",
-            "nnacl/kernel/scale.h",
-            "nnacl/fp32/scale_fp32.h",
+            "nnacl_c/scale_parameter.h",
+            "nnacl_c/kernel/scale.h",
+            "nnacl_c/fp32/scale_fp32.h",
           },
           {
             "scale_fp32_wrapper.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h
index 9a764e2e..2dd048c1 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/scale_fp32_coder.h
@@ -19,8 +19,8 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/scale_parameter.h"
-#include "nnacl/kernel/scale.h"
+#include "nnacl_c/scale_parameter.h"
+#include "nnacl_c/kernel/scale.h"

 namespace mindspore::lite::micro::nnacl {
 class ScaleFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc
index 8e8d88a6..2566297a 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.cc
@@ -17,8 +17,8 @@
 #include "tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h"
 #include "tools/converter/micro/coder/opcoders/file_collector.h"
 #include "tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
-#include "nnacl/slice_parameter.h"
-#include "nnacl/base/slice_base.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/base/slice_base.h"
 #include "coder/opcoders/parallel.h"

 using mindspore::schema::PrimitiveType_SliceFusion;
@@ -73,7 +73,7 @@ int SliceFP32Coder::Prepare(CoderContext *const context) {
 int SliceFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/base/slice_base.h",
+            "nnacl_c/base/slice_base.h",
             "wrapper/fp32/slice_fp32_wrapper.h",
           },
           {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h
index 21384323..37b999a2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/slice_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "tools/converter/micro/coder/opcoders/op_coder.h"
-#include "nnacl/kernel/slice.h"
+#include "nnacl_c/kernel/slice.h"

 namespace mindspore::lite::micro::nnacl {
 class SliceFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc
index 1bf18c06..e7075b55 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/softmax_fp32_coder.cc
@@ -36,8 +36,8 @@ int SoftMaxFP32Coder::Prepare(CoderContext *const context) {
 int SoftMaxFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/fp32/softmax_fp32.h",
-            "nnacl/fp32/log_softmax_fp32.h",
+            "nnacl_c/fp32/softmax_fp32.h",
+            "nnacl_c/fp32/log_softmax_fp32.h",
           },
           {
             "softmax_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc
index 4460f914..1529179d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc
@@ -19,7 +19,7 @@
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"
 #include "src/common/log_adapter.h"
-#include "nnacl/splice_parameter.h"
+#include "nnacl_c/splice_parameter.h"
 using mindspore::schema::PrimitiveType_Splice;
 namespace mindspore::lite::micro::nnacl {
 int SpliceFP32Coder::DoCode(CoderContext *const context) {
@@ -42,8 +42,8 @@ int SpliceFP32Coder::DoCode(CoderContext *const context) {
   }
   Collect(context,
           {
-            "nnacl/splice_parameter.h",
-            "nnacl/fp32/splice_fp32.h",
+            "nnacl_c/splice_parameter.h",
+            "nnacl_c/fp32/splice_fp32.h",
           },
           {
             "splice_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc
index c8778a7d..6392fcd6 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc
@@ -19,7 +19,7 @@
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/utils/coder_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 using mindspore::schema::PrimitiveType_Split;

@@ -64,7 +64,7 @@ int SplitDynamicFP32Coder::Prepare(CoderContext *const context) {
 }

 int SplitDynamicFP32Coder::DoCode(CoderContext *const context) {
-  Collect(context, {"nnacl/base/split_base.h"}, {"split_base.c"});
+  Collect(context, {"nnacl_c/base/split_base.h"}, {"split_base.c"});
   NNaclFp32Serializer code;
   code << " void *output_ptrs[" << output_tensors_.size() << "] = {";
   for (int i = 0; i < param_->num_split_; i++) {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h
index 88ca2bfe..efc253c0 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h
@@ -20,7 +20,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h"
-#include "nnacl/split_parameter.h"
+#include "nnacl_c/split_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class SplitDynamicFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc
index 43318576..88a30e33 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.cc
@@ -33,7 +33,7 @@ int SplitFP32Coder::Prepare(CoderContext *const context) {
 }

 int SplitFP32Coder::DoCode(CoderContext *const context) {
-  Collect(context, {"nnacl/base/split_base.h"}, {"split_base.c"});
+  Collect(context, {"nnacl_c/base/split_base.h"}, {"split_base.c"});
   if (support_parallel_) {
     Collect(context, {"wrapper/fp32/split_fp32_wrapper.h"}, {"split_fp32_wrapper.c"});
   }
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h
index f65214c1..32d75f2d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/split_parameter.h"
+#include "nnacl_c/split_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class SplitFP32Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc
index 5a663f1b..da8f1e11 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.cc
@@ -47,10 +47,10 @@ int TileFP32Coder::DoCode(CoderContext *const context) {
   // generate code .h .c
   Collect(context,
           {
-            "nnacl/fp32/tile.h",
+            "nnacl_c/fp32/tile.h",
           },
           {
-            "nnacl/fp32/tile.c",
+            "nnacl_c/fp32/tile.c",
           });

   NNaclFp32Serializer code;
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h
index c0627cc5..294e3f1b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/tile_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/base/tile_base.h"
+#include "nnacl_c/base/tile_base.h"

 namespace mindspore::lite::micro::nnacl {
 class TileFP32Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc
index 7fb160d5..1f2bade9 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc
@@ -42,9 +42,9 @@ int TransposeDynamicFp32Coder::Prepare(CoderContext *const context) {
 int TransposeDynamicFp32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/transpose_parameter.h",
-            "nnacl/errorcode.h",
-            "nnacl/fp32/transpose_fp32.h",
+            "nnacl_c/transpose_parameter.h",
+            "nnacl_c/errorcode.h",
+            "nnacl_c/fp32/transpose_fp32.h",
           },
           {
             "transpose_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h
index 9230b8e3..f956c281 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/transpose_parameter.h"
+#include "nnacl_c/transpose_parameter.h"
 #include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc
index 97b46775..4bacbb20 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc
@@ -107,9 +107,9 @@ int TransposeFp32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
             "wrapper/fp32/transpose_fp32_wrapper.h",
-            "nnacl/transpose_parameter.h",
-            "nnacl/errorcode.h",
-            "nnacl/fp32/transpose_fp32.h",
+            "nnacl_c/transpose_parameter.h",
+            "nnacl_c/errorcode.h",
+            "nnacl_c/fp32/transpose_fp32.h",
           },
           {
             "transpose_fp32_wrapper.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h
index 1b81dc26..737516d7 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_fp32_coder.h
@@ -19,7 +19,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/transpose_parameter.h"
+#include "nnacl_c/transpose_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class TransposeFp32Coder : public OperatorCoder {
  public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc
index b9d940d3..4b867f11 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc
@@ -15,7 +15,7 @@
  */
 #include "coder/opcoders/nnacl/fp32_grad/activation_grad_coder.h"
-#include "nnacl/fp32_grad/activation_grad_fp32.h"
+#include "nnacl_c/fp32_grad/activation_grad_fp32.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"

@@ -31,7 +31,7 @@ int ActivationGradCoder::DoCode(CoderContext *const context) {
   int count = input_tensor_->ElementsNum();
   Collect(context,
           {
-            "nnacl/fp32_grad/activation_grad_fp32.h",
+            "nnacl_c/fp32_grad/activation_grad_fp32.h",
           },
           {
             "activation_grad_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc
index c4427cf9..04821d8c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/adam_coder.cc
@@ -15,7 +15,7 @@
  */
 #include "coder/opcoders/nnacl/fp32_grad/adam_coder.h"
-#include "nnacl/fp32_grad/optimizer.h"
+#include "nnacl_c/fp32_grad/optimizer.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
 #include "coder/opcoders/file_collector.h"

@@ -53,7 +53,7 @@ int AdamCoder::DoCode(CoderContext *const context) {
   auto *adam_param = reinterpret_cast(parameter_);
   Collect(context,
           {
-            "nnacl/fp32/adam_fp32.h",
+            "nnacl_c/fp32/adam_fp32.h",
           },
           {
             "adam_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc
index c194f0ca..4b743fcd 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h"
 #include
-#include "nnacl/fp32_grad/softmax_crossentropy_parameter.h"
+#include "nnacl_c/fp32_grad/softmax_crossentropy_parameter.h"
 #include "coder/opcoders/file_collector.h"
 #include "schema/inner/ops_generated.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
@@ -49,8 +49,8 @@ int SoftmaxCrossEntropyWithLogitsCoder::DoCode(CoderContext *const context) {
   MS_CHECK_TRUE(input_tensors_.size() == DIMENSION_2D, "inputs size is not equal to two");
   Collect(context,
           {
-            "nnacl/fp32/softmax_fp32.h",
-            "nnacl/fp32_grad/softmax_cross_entropy_with_logits.h",
+            "nnacl_c/fp32/softmax_fp32.h",
+            "nnacl_c/fp32_grad/softmax_cross_entropy_with_logits.h",
           },
           {
             "softmax_fp32.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h
index 3aea3e4d..6161589b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CODER_H_
 #include
-#include "nnacl/softmax_parameter.h"
+#include "nnacl_c/softmax_parameter.h"
 #include "coder/opcoders/op_coder.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc
index 762dcd49..38cd39b2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/activation_int8_coder.cc
@@ -18,7 +18,7 @@
 #include "coder/opcoders/nnacl/int8/relux_int8_coder.h"
 #include "coder/opcoders/nnacl/int8/tanh_int8_coder.h"
 #include "src/common/ops/populate/populate_register.h"
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 #include "schema/model_generated.h"
 #include "src/common/version_manager.h"
 #include "coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h"
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc
index d9732ca6..266a2f32 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.cc
@@ -23,7 +23,7 @@
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"
 #include "coder/utils/common.h"
 #include "mindspore/ops/op_def/array_ops.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/quantize.h"

 using mindspore::schema::PrimitiveType_AddFusion;
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h
index 06809bd8..1b8a832e 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/add_int8_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/add_int8.h"
+#include "nnacl_c/int8/add_int8.h"

 namespace mindspore::lite::micro::nnacl {
 class AddInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h
index 97acf295..09eec680 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/affine_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/affine_parameter.h"
+#include "nnacl_c/affine_parameter.h"
 #include "tools/converter/micro/coder/wrapper/base/affine_wrapper.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc
index bf8a074d..b31fc592 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.cc
@@ -114,7 +114,7 @@ int ArithmeticSelfInt8Coder::Prepare(CoderContext *context) {
 int ArithmeticSelfInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/arithmetic_self_int8.h",
+            "nnacl_c/int8/arithmetic_self_int8.h",
           },
           {
             "arithmetic_self_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h
index 5dcf1373..3f7695ff 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/arithmetic_self_int8_coder.h
@@ -20,9 +20,9 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/arithmetic_self_int8.h"
-#include "nnacl/arithmetic_self_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/arithmetic_self_int8.h"
+#include "nnacl_c/arithmetic_self_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class ArithmeticSelfInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc
index 59a3df56..9616f469 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.cc
@@ -54,8 +54,8 @@ int BatchNormInt8Coder::DoCode(CoderContext *context) {

   Collect(context,
           {
-            "nnacl/slice_parameter.h",
-            "nnacl/kernel/batch_norm.h",
+            "nnacl_c/slice_parameter.h",
+            "nnacl_c/kernel/batch_norm.h",
           },
           {
             "batchnorm_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h
index 6cba41bc..a38e7fc7 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/batchnorm_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class BatchNormInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc
index 5aea116d..41b2eb5d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.cc
@@ -16,9 +16,9 @@
 #include "coder/opcoders/nnacl/int8/concat_int8_coder.h"
 #include
-#include "nnacl/int8/concat_int8.h"
-#include "nnacl/int8/arithmetic_int8.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/concat_int8.h"
+#include "nnacl_c/int8/arithmetic_int8.h"
+#include "nnacl_c/int8/quantize.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -100,7 +100,7 @@ int ConcatInt8Coder::DoCode(CoderContext *const context) {

   Collect(context,
           {
-            "nnacl/int8/concat_int8.h",
+            "nnacl_c/int8/concat_int8.h",
             "wrapper/int8/concat_int8_wrapper.h",
           },
           {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h
index 46a8bdb6..25c80fd7 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/concat_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/concat_int8.h"
+#include "nnacl_c/int8/concat_int8.h"
 #include "wrapper/int8/concat_int8_wrapper.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc
index 14727188..a8d11f51 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.cc
@@ -48,12 +48,12 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
             "wrapper/int8/conv1x1_init_int8_wrapper.h",
             "wrapper/int8/conv1x1_run_int8_wrapper.h",
             "wrapper/base/micro_parameter.h",
-            "nnacl/common_func.h",
-            "nnacl/base/conv1x1_base.h",
-            "nnacl/int8/matmul_int8.h",
-            "nnacl/int8/pack_int8.h",
-            "nnacl/int8/conv1x1_int8.h",
-            "nnacl/errorcode.h",
+            "nnacl_c/common_func.h",
+            "nnacl_c/base/conv1x1_base.h",
+            "nnacl_c/int8/matmul_int8.h",
+            "nnacl_c/int8/pack_int8.h",
+            "nnacl_c/int8/conv1x1_int8.h",
+            "nnacl_c/errorcode.h",
           },
           {
             "common_func.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h
index eff155ff..97b6843e 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_1x1_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "wrapper/base/micro_parameter.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc
index 9ebddef3..8e8b594d 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc
@@ -17,7 +17,7 @@
 #include "coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h"
 #include
 #include "include/securec.h"
-#include "nnacl/int8/conv3x3_int8.h"
+#include "nnacl_c/int8/conv3x3_int8.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/log.h"
 #include "coder/opcoders/parallel.h"
@@ -128,8 +128,8 @@ int Conv2D3x3Int8Coder::Prepare(CoderContext *const context) {
 int Conv2D3x3Int8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/conv_int8.h",
-            "nnacl/int8/conv3x3_int8.h",
+            "nnacl_c/int8/conv_int8.h",
+            "nnacl_c/int8/conv3x3_int8.h",
           },
           {
             "pack_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h
index bd96e3e2..d521ff1f 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class Conv2D3x3Int8Coder final : public Conv2DBaseCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc
index a5195199..99479ebe 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.cc
@@ -214,8 +214,8 @@ int Conv2DINT8Coder::DoCode(CoderContext *const context) {
   }
   Collect(context,
           {
-            "nnacl/int8/conv_int8.h",
-            "nnacl/common_func.h",
+            "nnacl_c/int8/conv_int8.h",
+            "nnacl_c/common_func.h",
             "wrapper/int8/convolution_int8_wrapper.h",
             "wrapper/base/common_wrapper.h",
             "wrapper/base/optimize_handler_wrapper.h",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h
index ddfc6913..bab8a56a 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/base/conv2d_base_coder.h"
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc
index 96a3e1e4..7b8b71b2 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/convolution_depthwise_int8_coder.cc
@@ -19,7 +19,7 @@
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"
-#include "nnacl/int8/conv_depthwise_int8.h"
+#include "nnacl_c/int8/conv_depthwise_int8.h"

 namespace mindspore::lite::micro {
 int ConvolutionDepthwiseINT8Coder::Prepare(CoderContext *const context) {
@@ -90,8 +90,8 @@ int ConvolutionDepthwiseINT8Coder::DoCode(CoderContext *const context) {
                 "Only support input channel equals output channel.");
   Collect(context,
           {
-            "nnacl/int8/conv_depthwise_int8.h",
-            "nnacl/int8/pack_int8.h",
+            "nnacl_c/int8/conv_depthwise_int8.h",
+            "nnacl_c/int8/pack_int8.h",
             "wrapper/int8/convolution_depthwise_int8_wrapper.h",
           },
           {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc
index 27fdc3f5..0dfdc993 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/int8/deconvolution_int8_coder.h"
 #include
-#include "nnacl/int8/deconv_int8.h"
+#include "nnacl_c/int8/deconv_int8.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"

@@ -125,7 +125,7 @@ int DeconvolutionInt8Coder::InitRunBuf(CoderContext *const context) {
 int DeconvolutionInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/deconv_int8.h",
+            "nnacl_c/int8/deconv_int8.h",
           },
           {
             "deconv_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h
index c8ae3ccb..0acd4f17 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/deconvolution_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/base/conv2d_base_coder.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class DeconvolutionInt8Coder final : public Conv2DBaseCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc
index 00783220..23f3d45b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/detection_post_process_int8_coder.cc
@@ -45,7 +45,7 @@ int DetectionPostProcessInt8Coder::GetInputData(CoderContext *const context, Ser

   Collect(context,
           {
-            "nnacl/int8/quant_dtype_cast_int8.h",
+            "nnacl_c/int8/quant_dtype_cast_int8.h",
           },
           {
             "quant_dtype_cast_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc
index e6b86e7c..83a967bf 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.cc
@@ -53,7 +53,7 @@ int DivInt8Coder::Prepare(CoderContext *context) {
 int DivInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/div_int8.h",
+            "nnacl_c/int8/div_int8.h",
           },
           {
             "div_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h
index 6cb9cb91..4a8859e0 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/div_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/quantize.h"
+#include "nnacl_c/int8/quantize.h"

 namespace mindspore::lite::micro::nnacl {
 class DivInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc
index fdefd4a5..c45d3cea 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.cc
@@ -15,7 +15,7 @@
  */
 #include "coder/opcoders/nnacl/int8/fullconnection_int8_coder.h"
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/int8/matmul_int8.h"
 #include "coder/log.h"

 using mindspore::schema::PrimitiveType_FullConnection;
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h
index c5946a0c..642f520b 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/fullconnection_int8_coder.h
@@ -21,8 +21,8 @@
 #include
 #include
 #include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/matmul_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class FullConnectionInt8Coder final : public MatMulBaseInt8Coder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc
index 59491e61..ed315947 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.cc
@@ -50,7 +50,7 @@ int GatherInt8Coder::Prepare(CoderContext *context) {
 int GatherInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/gather_int8.h",
+            "nnacl_c/int8/gather_int8.h",
           },
           {
             "gather_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h
index ee4f4c35..f3143efd 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/gather_int8_coder.h
@@ -20,9 +20,9 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/gather_int8.h"
-#include "nnacl/gather_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/gather_int8.h"
+#include "nnacl_c/gather_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class GatherInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc
index f04a6fd5..fbfbc466 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.cc
@@ -50,7 +50,7 @@ int LeakyReluInt8Coder::Prepare(CoderContext *context) {
 int LeakyReluInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/leaky_relu_int8.h",
+            "nnacl_c/int8/leaky_relu_int8.h",
           },
           {
             "leaky_relu_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h
index f76ec591..bd5fadf9 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h
@@ -20,9 +20,9 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/leaky_relu_int8.h"
-#include "nnacl/activation_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/leaky_relu_int8.h"
+#include "nnacl_c/activation_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class LeakyReluInt8Coder : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
index 6cf6c17f..aee53e95 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
@@ -274,11 +274,11 @@ void MatMulBaseInt8Coder::DoBatchCode(NNaclInt8Serializer *code_ptr) {
 int MatMulBaseInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/common_func.h",
-            "nnacl/int8/common_func_int8.h",
-            "nnacl/int8/matmul_int8.h",
-            "nnacl/int8/fixed_point.h",
-            "nnacl/int8/relux_int8.h",
+            "nnacl_c/common_func.h",
+            "nnacl_c/int8/common_func_int8.h",
+            "nnacl_c/int8/matmul_int8.h",
+            "nnacl_c/int8/fixed_point.h",
+            "nnacl_c/int8/relux_int8.h",
             "wrapper/int8/matmul_int8_wrapper.h",
           },
           {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h
index 28a9f271..10c5d2cb 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.h
@@ -19,7 +19,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"

 namespace mindspore::lite::micro::nnacl {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h
index bec68df5..5c8548b1 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_int8_coder.h
@@ -20,7 +20,7 @@
 #include
 #include "coder/opcoders/op_coder.h"
 #include "coder/opcoders/nnacl/int8/matmul_base_int8_coder.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 namespace mindspore::lite::micro::nnacl {
 class MatMulInt8Coder final : public MatMulBaseInt8Coder {
  public:
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc
index 2474f241..37343064 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.cc
@@ -205,7 +205,7 @@ int PadInt8Coder::HandleMirrorPad() {
 int PadInt8Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
-            "nnacl/int8/pad_int8.h",
+            "nnacl_c/int8/pad_int8.h",
           },
           {
             "pad_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h
index 3cd0b85a..8a3305fa 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pad_int8_coder.h
@@ -20,9 +20,9 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/int8/quantize.h"
-#include "nnacl/int8/pad_int8.h"
-#include "nnacl/pad_parameter.h"
+#include "nnacl_c/int8/quantize.h"
+#include "nnacl_c/int8/pad_int8.h"
+#include "nnacl_c/pad_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class PadInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc
index 1e639308..75c7491c 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.cc
@@ -16,7 +16,7 @@
 #include "coder/opcoders/nnacl/int8/pooling_int8_coder.h"
 #include
 #include
-#include "nnacl/int8/pooling_int8.h"
+#include "nnacl_c/int8/pooling_int8.h"
 #include "coder/log.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h"
 #include "coder/opcoders/file_collector.h"
@@ -54,9 +54,9 @@ int PoolingInt8Coder::DoCode(CoderContext *const context) {
   std::vector out_quant_args = out_tensor->quant_params();
   Collect(context,
           {
-            "nnacl/int8/pooling_int8.h",
-            "nnacl/kernel/pooling.h",
-            "nnacl/errorcode.h",
+            "nnacl_c/int8/pooling_int8.h",
+            "nnacl_c/kernel/pooling.h",
+            "nnacl_c/errorcode.h",
           },
           {
             "pooling_int8.c",
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h
index d255f77d..9400bde3 100644
--- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h
+++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/pooling_int8_coder.h
@@ -21,7 +21,7 @@
 #include
 #include
 #include "coder/opcoders/op_coder.h"
-#include "nnacl/kernel/pooling.h"
+#include "nnacl_c/kernel/pooling.h"

 namespace mindspore::lite::micro::nnacl {
 class PoolingInt8Coder final : public OperatorCoder {
diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h
index aacfe99d..c5281c57 100644
a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/prelu_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/nnacl/int8/leaky_relu_int8_coder.h" namespace mindspore::lite::micro::nnacl { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc index 6e77ded3..68d54982 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.cc @@ -235,8 +235,8 @@ int ReduceInt8Coder::DoCode(CoderContext *const context) { if (axes_hw_pattern_) { Collect(context, { - "nnacl/int8/pack_int8.h", - "nnacl/int8/reduce_int8.h", + "nnacl_c/int8/pack_int8.h", + "nnacl_c/int8/reduce_int8.h", }, { "pack_int8.c", @@ -256,7 +256,7 @@ int ReduceInt8Coder::DoCode(CoderContext *const context) { } else { Collect(context, { - "nnacl/int8/reduce_int8.h", + "nnacl_c/int8/reduce_int8.h", }, { "reduce_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h index 1fbf1efc..03bc752a 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reduce_int8_coder.h @@ -20,8 +20,8 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/int8/reduce_int8.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/int8/reduce_int8.h" #include "coder/opcoders/base/reduce_base_coder.h" namespace mindspore::lite::micro::nnacl { class ReduceInt8Coder final : public ReduceBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc index 5ddf04b1..0f67e4a8 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.cc @@ -15,7 +15,7 @@ */ #include "coder/opcoders/nnacl/int8/relux_int8_coder.h" -#include "nnacl/fp32/activation_fp32.h" +#include "nnacl_c/fp32/activation_fp32.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" #include "coder/opcoders/file_collector.h" #include "coder/log.h" @@ -41,7 +41,7 @@ int ReluxInt8Coder::Prepare(CoderContext *const context) { int ReluxInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/relux_int8.h", + "nnacl_c/int8/relux_int8.h", }, { "relux_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h index aded1729..79ce4146 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/relux_int8_coder.h @@ -21,7 +21,7 @@ #include #include "coder/opcoders/op_coder.h" #include "coder/utils/common.h" -#include "nnacl/int8/relux_int8.h" +#include "nnacl_c/int8/relux_int8.h" #include "coder/log.h" #include 
"include/errorcode.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc index 1533307b..2b2fb9d2 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/reshape_int8_coder.cc @@ -35,7 +35,7 @@ int ReshapeInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/reshape_int8.h", + "nnacl_c/int8/reshape_int8.h", }, { "reshape_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc index f2fea6dc..60843532 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.cc @@ -19,7 +19,7 @@ #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" #include "coder/opcoders/file_collector.h" #include "include/securec.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/parallel.h" using mindspore::schema::PrimitiveType_Resize; @@ -67,7 +67,7 @@ int ResizeInt8Coder::ReSize() { int ResizeInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/resize_int8.h", + "nnacl_c/int8/resize_int8.h", "wrapper/int8/resize_int8_wrapper.h", }, { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h index 9b7e6a84..3300de10 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/resize_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/base/resize_base_coder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::micro::nnacl { class ResizeInt8Coder final : public ResizeBaseCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc index 6f29562a..eea8127f 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sigmoid_int8_coder.cc @@ -55,7 +55,7 @@ int SigmodInt8Coder::Prepare(CoderContext *const context) { int SigmodInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/sigmoid_int8.h", + "nnacl_c/int8/sigmoid_int8.h", }, { "sigmoid_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc index 452a452e..492e1d81 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/softmax_int8_coder.cc @@ -19,7 +19,7 @@ #include #include #include "schema/inner/ops_generated.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "coder/log.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" #include "coder/opcoders/file_collector.h" @@ -72,7 +72,7 @@ int 
SoftMaxInt8Coder::DoCode(CoderContext *const context) { "n_dim should be less than the length of maximum value of input_shape"); Collect(context, { - "nnacl/int8/softmax_int8.h", + "nnacl_c/int8/softmax_int8.h", }, { "softmax_int8.c", diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc index e08f4549..cbf0e03e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.cc @@ -16,6 +16,7 @@ #include "coder/opcoders/nnacl/int8/sub_int8_coder.h" #include +#include #include "include/errorcode.h" #include "coder/log.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" @@ -73,7 +74,7 @@ int SubInt8Coder::Prepare(CoderContext *const context) { } int SubInt8Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/int8/arithmetic_int8.h", "nnacl/int8/sub_int8.h"}, {"arithmetic_int8.c", "sub_int8.c"}); + Collect(context, {"nnacl_c/int8/arithmetic_int8.h", "nnacl_c/int8/sub_int8.h"}, {"arithmetic_int8.c", "sub_int8.c"}); NNaclInt8Serializer code; // Todo: Parallel run wrapper auto element_num = output_tensor_->ElementsNum(); diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h index a616c143..79aa9ae3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/sub_int8_coder.h @@ -20,7 +20,7 @@ #include #include #include "coder/opcoders/op_coder.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro::nnacl { class SubInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc index 6704d9ee..23f84516 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/tanh_int8_coder.cc @@ -21,7 +21,7 @@ #include "include/errorcode.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h" -#include "nnacl/int8/tanh_int8.h" +#include "nnacl_c/int8/tanh_int8.h" namespace mindspore::lite::micro::nnacl { int TanhInt8Coder::Prepare(CoderContext *const context) { return RET_OK; } @@ -29,7 +29,7 @@ int TanhInt8Coder::Prepare(CoderContext *const context) { return RET_OK; } int TanhInt8Coder::DoCode(CoderContext *const context) { Collect(context, { - "nnacl/int8/tanh_int8.h", + "nnacl_c/int8/tanh_int8.h", }, {"tanh_int8.c", "activation_fp32.c"}); diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc index 43a19abc..9fa23bb7 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.cc @@ -49,7 +49,7 @@ int TransposeInt8Coder::Prepare(CoderContext *const context) { } int TransposeInt8Coder::DoCode(CoderContext *const context) { - Collect(context, {"nnacl/int8/pack_int8.h", 
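The `// Todo: Parallel run wrapper` left in sub_int8_coder.cc refers to the convention the wrapper files later in this patch follow (DoSigmoid, DoConcatRun, ConvFp32Run, and the rest): a C entry point with the signature int Fn(void *cdata, int task_id, float lhs_scale, float rhs_scale) that casts cdata to an op-specific args struct and processes one task slice. A hedged sketch of what such a wrapper for Sub could look like; SubInt8Args, its fields, and the slicing are assumptions, only the signature and the cast idiom come from this patch:

    // Hypothetical wrapper; the struct fields and the kernel call are assumptions.
    typedef struct {
      const int8_t *in0_;
      const int8_t *in1_;
      int8_t *out_;
      int element_num_;
      int thread_num_;
    } SubInt8Args;

    int SubInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
      SubInt8Args *args = (SubInt8Args *)cdata;  // same cast idiom as DoSigmoid below
      int stride = UP_DIV(args->element_num_, args->thread_num_);
      int offset = stride * task_id;
      int count = MSMIN(stride, args->element_num_ - offset);
      if (count <= 0) {
        return NNACL_OK;
      }
      // invoke the per-slice nnacl_c sub kernel over [offset, offset + count) here
      return NNACL_OK;
    }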
"nnacl/int8/transpose_int8.h"}, {"pack_int8.c", "transpose_int8.c"}); + Collect(context, {"nnacl_c/int8/pack_int8.h", "nnacl_c/int8/transpose_int8.h"}, {"pack_int8.c", "transpose_int8.c"}); NNaclInt8Serializer code; auto out_shape = output_tensors_[0]->shape(); diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h index 12e06f86..0c49a6f3 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/transpose_int8_coder.h @@ -18,7 +18,7 @@ #include #include "coder/opcoders/op_coder.h" -#include "nnacl/transpose_parameter.h" +#include "nnacl_c/transpose_parameter.h" namespace mindspore::lite::micro::nnacl { class TransposeInt8Coder final : public OperatorCoder { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h index 4d27c700..8201f56e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h @@ -20,43 +20,43 @@ #include #include #include "coder/opcoders/serializers/serializer.h" -#include "nnacl/batchnorm_parameter.h" -#include "nnacl/fp32/arithmetic_fp32.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" +#include "nnacl_c/fp32/arithmetic_fp32.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "wrapper/base/micro_parameter.h" -#include "nnacl/scale_parameter.h" -#include "nnacl/slice_parameter.h" -#include "nnacl/split_parameter.h" -#include "nnacl/transpose_parameter.h" -#include "nnacl/base/tile_base.h" -#include "nnacl/fp32/transpose_fp32.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/splice_parameter.h" -#include "nnacl/lstm_parameter.h" -#include "nnacl/group_norm_parameter.h" -#include "nnacl/activation_parameter.h" +#include "nnacl_c/scale_parameter.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/split_parameter.h" +#include "nnacl_c/transpose_parameter.h" +#include "nnacl_c/base/tile_base.h" +#include "nnacl_c/fp32/transpose_fp32.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/splice_parameter.h" +#include "nnacl_c/lstm_parameter.h" +#include "nnacl_c/group_norm_parameter.h" +#include "nnacl_c/activation_parameter.h" #include "wrapper/fp32/dequant_int8_to_fp32_wrapper.h" -#include "nnacl/fp32/exp_fp32.h" -#include "nnacl/fp32/strided_slice_fp32.h" -#include "nnacl/tensor_c.h" +#include "nnacl_c/fp32/exp_fp32.h" +#include "nnacl_c/fp32/strided_slice_fp32.h" +#include "nnacl_c/tensor_c.h" #include "wrapper/fp32/arithmetic_fp32_wrapper.h" #include "wrapper/base/affine_wrapper.h" #include "wrapper/fp32/conv_winograd_fp32_wrapper.h" -#include "nnacl/instance_norm_parameter.h" -#include "nnacl/layer_norm_parameter.h" -#include "nnacl/broadcast_to_parameter.h" -#include "nnacl/custom_gru_parameter.h" -#include "nnacl/unstack_parameter.h" -#include "nnacl/kernel/scale.h" -#include "nnacl/kernel/pooling.h" -#include "nnacl/kernel/layer_norm.h" -#include "nnacl/kernel/fill.h" -#include 
"nnacl/kernel/batch_norm.h" -#include "nnacl/kernel/tile.h" -#include "nnacl/kernel/slice.h" -#include "nnacl/kernel/strided_slice.h" +#include "nnacl_c/instance_norm_parameter.h" +#include "nnacl_c/layer_norm_parameter.h" +#include "nnacl_c/broadcast_to_parameter.h" +#include "nnacl_c/custom_gru_parameter.h" +#include "nnacl_c/unstack_parameter.h" +#include "nnacl_c/kernel/scale.h" +#include "nnacl_c/kernel/pooling.h" +#include "nnacl_c/kernel/layer_norm.h" +#include "nnacl_c/kernel/fill.h" +#include "nnacl_c/kernel/batch_norm.h" +#include "nnacl_c/kernel/tile.h" +#include "nnacl_c/kernel/slice.h" +#include "nnacl_c/kernel/strided_slice.h" #include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h" #include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h index 245c1141..dd016f6e 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.h @@ -18,26 +18,26 @@ #include #include #include "wrapper/base/affine_wrapper.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/softmax_parameter.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/softmax_parameter.h" #include "coder/opcoders/serializers/serializer.h" -#include "nnacl/op_base.h" -#include "nnacl/int8/add_int8.h" -#include "nnacl/int8/arithmetic_int8.h" -#include "nnacl/conv_parameter.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/int8/add_int8.h" +#include "nnacl_c/int8/arithmetic_int8.h" +#include "nnacl_c/conv_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "wrapper/base/micro_parameter.h" -#include "nnacl/int8/concat_int8.h" -#include "nnacl/int8/quantize.h" -#include "nnacl/reshape_parameter.h" -#include "nnacl/slice_parameter.h" -#include "nnacl/batchnorm_parameter.h" -#include "nnacl/pad_parameter.h" -#include "nnacl/transpose_parameter.h" -#include "nnacl/int8/relux_int8.h" +#include "nnacl_c/int8/concat_int8.h" +#include "nnacl_c/int8/quantize.h" +#include "nnacl_c/reshape_parameter.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/batchnorm_parameter.h" +#include "nnacl_c/pad_parameter.h" +#include "nnacl_c/transpose_parameter.h" +#include "nnacl_c/int8/relux_int8.h" #include "wrapper/int8/concat_int8_wrapper.h" -#include "nnacl/kernel/pooling.h" -#include "nnacl/kernel/batch_norm.h" +#include "nnacl_c/kernel/pooling.h" +#include "nnacl_c/kernel/batch_norm.h" namespace mindspore::lite::micro::nnacl { class NNaclInt8Serializer : public Serializer { diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc index b4e2e2ed..c7050592 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.cc @@ -16,11 +16,11 @@ #include #include -#include "nnacl/pooling_parameter.h" -#include "nnacl/slice_parameter.h" -#include 
"nnacl/softmax_parameter.h" -#include "nnacl/int8/add_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/int8/add_int8.h" +#include "nnacl_c/int8/quantize.h" #include "coder/opcoders/parallel.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h" diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h index 6a219bec..ae06d287 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_stream_utils.h @@ -18,12 +18,12 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_SERIALIZERS_NNACL_SERIALIZER_NNACL_STREAM_UTILS_H_ #include #include -#include "nnacl/op_base.h" -#include "nnacl/pooling_parameter.h" -#include "nnacl/slice_parameter.h" -#include "nnacl/softmax_parameter.h" -#include "nnacl/int8/add_int8.h" -#include "nnacl/int8/quantize.h" +#include "nnacl_c/op_base.h" +#include "nnacl_c/pooling_parameter.h" +#include "nnacl_c/slice_parameter.h" +#include "nnacl_c/softmax_parameter.h" +#include "nnacl_c/int8/add_int8.h" +#include "nnacl_c/int8/quantize.h" namespace mindspore::lite::micro { std::ostream &operator<<(std::ostream &code, const ::QuantArg &quant_arg); diff --git a/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h b/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h index d8c03ab6..4c6b1a0d 100644 --- a/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h +++ b/mindspore-lite/tools/converter/micro/coder/utils/type_cast.h @@ -25,7 +25,7 @@ #include "ir/dtype/type_id.h" #include "include/api/format.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/micro/coder/config.h" #include "base/float16.h" diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h index c0a3e6e6..1cad479a 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/affine_wrapper.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_AFFINE_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_AFFINE_WRAPPER_H_ -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h index c1c6e9aa..73349cee 100644 --- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h +++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/common_wrapper.h @@ -17,7 +17,7 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_COMMON_WRAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_COMMON_WRAPPER_H_ -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" bool GetSupportOptFlag(); #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_COMMON_WRAPPER_H_ diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h 
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h
index 3763bcdf..e8bf7eec 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/micro_parameter.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_MICRO_PARAMETER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_MICRO_PARAMETER_H_
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 typedef struct {
   ActType act_type_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h
index 8d3000f0..8835550a 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/optimize_handler_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_OPTIMIZE_HANDLER_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_BASE_OPTIMIZE_HANDLER_WRAPPER_H_
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #ifdef ENABLE_ARM64
 void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c
index 2a3f1715..60a54327 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/base/strided_slice_wrapper.h"
-#include "nnacl/fp32/strided_slice_fp32.h"
+#include "nnacl_c/fp32/strided_slice_fp32.h"
 int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) {
   StridedSliceStruct strided_slice;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h
index 8beb4051..b9d7ede0 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/base/strided_slice_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_BASE_STRIDED_SLICE_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_BASE_STRIDED_SLICE_WRAPPER_H_
 #include
-#include "nnacl/strided_slice_parameter.h"
+#include "nnacl_c/strided_slice_parameter.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c
index 07e0f103..4dec9396 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/activation_fp32_wrapper.h"
-#include "nnacl/fp32/activation_fp32.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32/activation_fp32.h"
+#include "nnacl_c/errorcode.h"
 int DoSigmoid(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ActivationFp32Args *args = (ActivationFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h
index 64a13b0e..d4eb98c0 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/activation_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_ACTIVATION_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_ACTIVATION_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl_c/fp32/activation_fp32.h"
 typedef struct {
   const float *input_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h
index 7c4d3001..3e58197e 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/arithmetic_fp32_wrapper.h
@@ -15,7 +15,7 @@
  */
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_ARITHMETIC_FP32_WRAPPER_H_
-#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl_c/fp32/arithmetic_fp32.h"
 #include
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c
index 144a1e93..6052ffca 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/concat_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/concat_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
-#include "nnacl/base/concat_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/base/concat_base.h"
 int DoConcatRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ConcatFp32Args *args = (ConcatFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c
index e1dabe5a..7f2810c6 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.c
@@ -16,8 +16,8 @@
 #include "wrapper/fp32/conv_fp32_wrapper.h"
 #include
-#include "nnacl/errorcode.h"
-#include "nnacl/fp32/conv_common_fp32.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/fp32/conv_common_fp32.h"
 int ConvFp32Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ConvFp32Args *args = (ConvFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h
index 4ee4ecb9..96aab4e7 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_fp32_wrapper.h
@@ -15,7 +15,7 @@
  */
 #ifndef MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_WRAPPER_FP32_CONV_FP32_WRAPPER_H_
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c
index 69a431e5..9153ee3d 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/conv_winograd_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/errorcode.h"
 int ConvWinogradFp32Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ConvWinogradFp32Args *args = (ConvWinogradFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h
index fe91a288..f71bc645 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/conv_winograd_fp32_wrapper.h
@@ -15,8 +15,8 @@
  */
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_CONV_WINOGRAD_FP32_WRAPPER_H_
-#include "nnacl/fp32/winograd_utils.h"
-#include "nnacl/fp32/conv_winograd_fp32.h"
+#include "nnacl_c/fp32/winograd_utils.h"
+#include "nnacl_c/fp32/conv_winograd_fp32.h"
 #ifdef __cplusplus
 #include
 typedef struct TransFuncStr {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c
index 917effc8..4a438edb 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/deconvolution_fp32_wrapper.h"
-#include "nnacl/fp32/deconv_fp32.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/deconv_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 int DoDeconvFp32(const float *packed_input, const float *packed_weight, const float *packed_bias, float *packed_output,
                  float *output, float *tmp_ori_buffer, const MicroMatmulParameter *matmul_param,
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h
index 64d18ef2..78bade2a 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/deconvolution_fp32_wrapper.h
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_DECONVOLUTION_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_DECONVOLUTION_FP32_WRAPPER_H_
-#include "nnacl/errorcode.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "wrapper/base/micro_parameter.h"
 typedef struct {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c
index e102fe0c..bb7827bd 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/fill_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/fill_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
-#include "nnacl/base/fill_base.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/base/fill_base.h"
 int DoFillFp32(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   FillFp32Args *args = (FillFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c
index a0a2dd95..75551a98 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/matmul_fp32_wrapper.h"
-#include "nnacl/fp32/pack_fp32.h"
+#include "nnacl_c/fp32/pack_fp32.h"
 void InitMatrixA(const float *src_ptr, float *dst_ptr, const MicroMatmulParameter *params_, bool is_vector_a) {
   if (is_vector_a) {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h
index 4303a6ce..c3b34d57 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/matmul_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_MATMUL_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_MATMUL_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 #include "wrapper/base/micro_parameter.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c
index e6253af1..61705faf 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/pooling_fp32_wrapper.h"
-#include "nnacl/fp32/pooling_fp32.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
+#include "nnacl_c/errorcode.h"
 int DoMaxPooling(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   PoolingFp32Args *args = (PoolingFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h
index 34062d3c..72326f74 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/pooling_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_POOLING_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_POOLING_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/pooling_fp32.h"
+#include "nnacl_c/fp32/pooling_fp32.h"
 typedef struct {
   const float *input_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c
index 69ff70fc..2ffd06ee 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/fp32/scale_fp32_wrapper.h"
-#include "nnacl/fp32/scale_fp32.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32/scale_fp32.h"
+#include "nnacl_c/errorcode.h"
 int DoScaleReluRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ScaleFp32Args *args = (ScaleFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h
index 7fd71387..5bc68886 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/scale_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SCALE_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SCALE_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/scale_fp32.h"
+#include "nnacl_c/fp32/scale_fp32.h"
 typedef struct {
   const float *input_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c
index 6227fbd0..68bb194e 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/slice_fp32_wrapper.h"
-#include "nnacl/base/slice_base.h"
+#include "nnacl_c/base/slice_base.h"
 int DoSliceRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   SliceFp32Args *args = (SliceFp32Args *)(cdata);
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h
index 33aa514d..f1a829c7 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/slice_fp32_wrapper.h
@@ -18,8 +18,8 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_FP32_SLICE_FP32_WRAPPER_H_
 #include
-#include "nnacl/slice_parameter.h"
-#include "nnacl/kernel/slice.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/kernel/slice.h"
 typedef struct {
   float *input_data_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c
index 8e42305a..dae81be5 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/fp32/split_fp32_wrapper.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/errorcode.h"
 int DoSplitRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   SplitFp32Args *args = (SplitFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h
index 9309be50..e6bc6557 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/split_fp32_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SPLIT_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_SPLIT_FP32_WRAPPER_H_
 #include
-#include "nnacl/base/split_base.h"
+#include "nnacl_c/base/split_base.h"
 typedef struct {
   const void *input_ptr_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c
index b75da08a..88de73b3 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.c
@@ -16,8 +16,8 @@
 #include "wrapper/fp32/transpose_fp32_wrapper.h"
 #include
-#include "nnacl/fp32/pack_fp32.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/fp32/pack_fp32.h"
+#include "nnacl_c/errorcode.h"
 int DoTransposeNCHWToNHWC(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   TransposeFp32Args *args = (TransposeFp32Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h
index 61462ff1..b63a9b12 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/fp32/transpose_fp32_wrapper.h
@@ -17,8 +17,8 @@
 #ifndef MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_TRANSPOSE_FP32_WRAPPER_H_
 #define MINDSPORE_LITE_MICRO_CODER_OPERATOR_LIBRARY_WRAPPER_FP32_TRANSPOSE_FP32_WRAPPER_H_
 #include
-#include "nnacl/fp32/transpose_fp32.h"
-#include "nnacl/transpose_parameter.h"
+#include "nnacl_c/fp32/transpose_fp32.h"
+#include "nnacl_c/transpose_parameter.h"
 typedef struct {
   const void *input_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c
index df8287b6..aac9a581 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/int8/add_int8_wrapper.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/errorcode.h"
 int AddBroadcastInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   AddInt8Args *args = (AddInt8Args *)(cdata);
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h
index ad4e09c6..8411709a 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/add_int8_wrapper.h
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_ADD_INT8_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_ADD_INT8_WRAPPER_H_
 #include
-#include "nnacl/int8/matmul_int8.h"
-#include "nnacl/int8/add_int8.h"
-#include "nnacl/arithmetic_parameter.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/int8/add_int8.h"
+#include "nnacl_c/arithmetic_parameter.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c
index 33777ee4..7d8494d5 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/int8/batchnorm_int8_wrapper.h"
-#include "nnacl/int8/batchnorm_int8.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/int8/batchnorm_int8.h"
+#include "nnacl_c/errorcode.h"
 int BatchNormInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   BatchNormArgs *args = (BatchNormArgs *)(cdata);
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h
index cac6f1ec..28c349cd 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/batchnorm_int8_wrapper.h
@@ -18,7 +18,7 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_BATCHNORM_INT8_WRAPPER_H_
 #include
-#include "nnacl/batchnorm_parameter.h"
+#include "nnacl_c/batchnorm_parameter.h"
 typedef struct BatchNormArgs {
   int8_t *in_addr_;
   int8_t *out_addr_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h
index cc23a389..a019056a 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/concat_int8_wrapper.h
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONCAT_INT8_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONCAT_INT8_WRAPPER_H_
-#include "nnacl/errorcode.h"
-#include "nnacl/concat_parameter.h"
-#include "nnacl/int8/concat_int8.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/concat_parameter.h"
+#include "nnacl_c/int8/concat_int8.h"
 typedef struct {
   int8_t **inputs_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c
index e7c604e3..f9ffb0a7 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.c
@@ -15,8 +15,8 @@
  */
 #include "wrapper/int8/conv1x1_init_int8_wrapper.h"
-#include "nnacl/int8/matmul_int8.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/errorcode.h"
 size_t Conv1x1PackWeightSize(int32_t input_channel, int32_t output_channel, bool support_optimize) {
   size_t size = 0;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h
index 1a721d57..17b411ec 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_init_int8_wrapper.h
@@ -19,7 +19,7 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
+#include "nnacl_c/conv_parameter.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c
index f9df30f2..b8a67647 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.c
@@ -15,11 +15,11 @@
  */
 #include "wrapper/int8/conv1x1_run_int8_wrapper.h"
-#include "nnacl/base/conv1x1_base.h"
-#include "nnacl/int8/matmul_int8.h"
-#include "nnacl/int8/pack_int8.h"
-#include "nnacl/int8/conv1x1_int8.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/base/conv1x1_base.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/int8/pack_int8.h"
+#include "nnacl_c/int8/conv1x1_int8.h"
+#include "nnacl_c/errorcode.h"
 void Pre1x1Trans(Conv1x1Args *args, int8_t *src_input, int8_t *src_output) {
   args->output_ptr_ = src_output;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h
index 6a51d960..f1c94070 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv1x1_run_int8_wrapper.h
@@ -19,8 +19,8 @@
 #include
 #include
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "wrapper/base/micro_parameter.h"
 typedef struct {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h
index d2dad072..1663d987 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv3x3_run_int8_wrapper.h
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONV3X3_RUN_INT8_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONV3X3_RUN_INT8_WRAPPER_H_
-#include "nnacl/errorcode.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/int8/conv3x3_int8.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/int8/conv3x3_int8.h"
 typedef struct {
   int16_t *input_data;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c
index d7454de5..2cf29d3c 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/conv_init_int8_wrapper.c
@@ -15,9 +15,9 @@
  */
 #include "wrapper/int8/conv_init_int8_wrapper.h"
-#include "nnacl/op_base.h"
-#include "nnacl/int8/matmul_int8.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/op_base.h"
+#include "nnacl_c/int8/matmul_int8.h"
+#include "nnacl_c/errorcode.h"
 size_t ConvPackWeightSize(int input_channel, int output_channel, int kernel_plane, bool support_optimize) {
   size_t up_round_deep;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h
index cd5c9e89..21493fe8 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_depthwise_int8_wrapper.h
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_DEPTHWISE_INT8_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_DEPTHWISE_INT8_WRAPPER_H_
-#include "nnacl/errorcode.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/int8/conv_depthwise_int8.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/int8/conv_depthwise_int8.h"
 typedef struct {
   int8_t *output_data_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h
index 00a4ce9f..166cbbef 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/convolution_int8_wrapper.h
@@ -17,10 +17,10 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_INT8_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_CONVOLUTION_INT8_WRAPPER_H_
-#include "nnacl/errorcode.h"
-#include "nnacl/conv_parameter.h"
-#include "nnacl/matmul_parameter.h"
-#include "nnacl/int8/conv_int8.h"
+#include "nnacl_c/errorcode.h"
+#include "nnacl_c/conv_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
+#include "nnacl_c/int8/conv_int8.h"
 typedef struct {
   int8_t *input_data_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h
index ea2f340f..87666ed5 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/matmul_int8_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_MATMUL_INT8_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_MATMUL_INT8_WRAPPER_H_
 #include
-#include "nnacl/int8/matmul_int8.h"
+#include "nnacl_c/int8/matmul_int8.h"
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c
index 81dc1578..f82f18af 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/int8/resize_int8_wrapper.h"
-#include "nnacl/errorcode.h"
+#include "nnacl_c/errorcode.h"
 int ResizeInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   ResizeInt8Args *args = (ResizeInt8Args *)cdata;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h
index f66d244c..5cbb1620 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/resize_int8_wrapper.h
@@ -17,7 +17,7 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_RESIZE_INT8_WRAPPER_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_RESIZE_INT8_WRAPPER_H_
-#include "nnacl/int8/resize_int8.h"
+#include "nnacl_c/int8/resize_int8.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c
index 5581bdc0..348e39d8 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.c
@@ -15,7 +15,7 @@
  */
 #include "wrapper/int8/slice_int8_wrapper.h"
-#include "nnacl/int8/slice_int8.h"
+#include "nnacl_c/int8/slice_int8.h"
 int SliceInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   SliceArgs *args = (SliceArgs *)(cdata);
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h
index e593ef7b..b4e703d7 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/int8/slice_int8_wrapper.h
@@ -18,8 +18,8 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_WRAPPER_INT8_SLICE_INT8_WRAPPER_H_
 #include
-#include "nnacl/slice_parameter.h"
-#include "nnacl/kernel/slice.h"
+#include "nnacl_c/slice_parameter.h"
+#include "nnacl_c/kernel/slice.h"
 typedef struct SliceArgs {
   int8_t *input_data_;
diff --git a/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c b/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c
index 5a452add..c388fec0 100644
--- a/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c
+++ b/mindspore-lite/tools/converter/micro/coder/wrapper/thread/micro_core_affinity.c
@@ -25,7 +25,7 @@
 #include
 #include
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 int GetCpuCoreNum() {
   int core_num = 1;
diff --git a/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h b/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h
index b4698a2f..1ae97a2d 100644
--- a/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h
+++ b/mindspore-lite/tools/converter/micro/providers/nnie/nnie_micro.h
@@ -17,8 +17,8 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_PROVIDERS_NNIE_NNIE_MICRO_H_
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_PROVIDERS_NNIE_NNIE_MICRO_H_
-#include "nnacl/custom_parameter.h"
-#include "nnacl/tensor_c.h"
+#include "nnacl_c/custom_parameter.h"
+#include "nnacl_c/tensor_c.h"
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore-lite/tools/converter/offline_packing_optimizer.cc b/mindspore-lite/tools/converter/offline_packing_optimizer.cc
index c2532388..8a4981fd 100644
--- a/mindspore-lite/tools/converter/offline_packing_optimizer.cc
+++ b/mindspore-lite/tools/converter/offline_packing_optimizer.cc
@@ -27,7 +27,7 @@
 #include "src/common/primitive_t_utils.h"
 #include "src/common/ops/anf_utils.h"
 #include "src/common/file_utils.h"
-#include "nnacl/matmul_parameter.h"
+#include "nnacl_c/matmul_parameter.h"
 #include "src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h"
diff --git a/mindspore-lite/tools/converter/ops/while.cc b/mindspore-lite/tools/converter/ops/while.cc
index 49f66223..15ef4c01 100644
--- a/mindspore-lite/tools/converter/ops/while.cc
+++ b/mindspore-lite/tools/converter/ops/while.cc
@@ -19,7 +19,7 @@
 #include "tools/converter/ops/while.h"
 #include "utils/check_convert_utils.h"
 #include "abstract/ops/primitive_infer_map.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/optimizer_manager.cc b/mindspore-lite/tools/converter/optimizer_manager.cc
index 44d66dda..e567ac04 100644
--- a/mindspore-lite/tools/converter/optimizer_manager.cc
+++ b/mindspore-lite/tools/converter/optimizer_manager.cc
@@ -24,7 +24,7 @@
 #include "src/common/log_util.h"
 #include "tools/converter/parser/parser_utils.h"
 #include "include/registry/pass_base.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc
index 8681801f..69309b2f 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_activation_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_activation_parser.h"
 #include
 #include "infer/cxx_api/activation.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/utils.h"
 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc
index a7e34625..38a52abe 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_argmax_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_argmax_parser.h"
 #include
 #include "infer/cxx_api/arg_max_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc
index 8dd36c92..102f78e5 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_batchnorm_parser.cc
@@ -21,7 +21,7 @@
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "ops_utils/op_utils.h"
 #include "include/registry/converter_context.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc
index 59e41783..fc5cf48f 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_concat_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_concat_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc
index 6d749593..766d41b4 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_conv_base_parser.cc
@@ -16,7 +16,7 @@
 #include "tools/converter/parser/caffe/caffe_conv_base_parser.h"
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/log_util.h"
 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc
index 379ea42c..ac8591e5 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_convolution_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_convolution_parser.h"
 #include
 #include "infer/cxx_api/conv2d_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc
index 3de900b0..f41c9236 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_crop_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_crop_parser.h"
 #include
 #include "infer/crop.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc
index ab415486..56c2566a 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_deconvolution_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include "infer/cxx_api/conv2d_transpose_fusion.h"
 #include "include/registry/converter_context.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/op_name.h"
 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc
index 2fe0e838..e66c0f8b 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_eltwise_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/eltwise.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc
index f9adf5a1..8a9c5c09 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_exp_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/exp_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc
index 3c711bd5..927a43e0 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_flatten_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_flatten_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc
index 882f3542..c0f74409 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_innerproduct_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include "infer/cxx_api/full_connection.h"
 #include "ops_utils/op_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc
index 8cc1d71c..2e944ba4 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_interp_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_interp_parser.h"
 #include
 #include "infer/resize.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc
index ce29d1d9..f2cc41e7 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_model_parser.cc
@@ -32,7 +32,7 @@
 #include "tools/converter/parser/lite_model_parser_creator.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/converter/parser/unify_format.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/log_util.h"
 #include "infer/make_tuple.h"
 #include "infer/return.h"
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc
index f11c024f..ec47182b 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc
@@ -19,7 +19,7 @@
 #include "include/securec.h"
 #include "ir/dtype/type_id.h"
 #include "src/common/utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/log_util.h"
 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc
index 0ab3ce1e..a997b25d 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_permute_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_permute_parser.h"
 #include
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc
index d4ff9271..80973264 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_pooling_parser.cc
@@ -20,7 +20,7 @@
 #include "infer/cxx_api/max_pool_fusion.h"
 #include "ops_utils/op_utils.h"
 #include "include/registry/converter_context.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "src/common/log_util.h"
 namespace mindspore {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc
index 386716b6..a0ef44ed 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_power_parser.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "infer/cxx_api/pow_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc
index e0ce4d03..109bc83f 100644
--- a/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc
+++ b/mindspore-lite/tools/converter/parser/caffe/caffe_prelu_parser.cc
@@ -17,7 +17,7 @@
 #include "tools/converter/parser/caffe/caffe_prelu_parser.h"
 #include
 #include "infer/cxx_api/prelu_fusion.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 namespace mindspore {
 namespace lite {
diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc
index
8ccceb31..1e5ca99b 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_quantize_parser.cc @@ -18,7 +18,7 @@ #include "tools/converter/parser/caffe/caffe_quantize_parser.h" #include #include "infer/quant_dtype_cast.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc index 4c4f88e1..7923ae84 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_reduce_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/reduce_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc index fa31d9fd..ecc74869 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_reshape_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_reshape_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc index ac96c8c3..fbe30ac4 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_scale_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_scale_parser.h" #include #include "infer/cxx_api/scale_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc index 47ef93d3..01c461cb 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_slice_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_slice_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc index 6f19e643..f1ac8833 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_softmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/caffe/caffe_softmax_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc index ad3c8993..79b0fb98 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_tile_parser.cc @@ -18,7 +18,7 @@ #include #include 
#include "infer/cxx_api/tile_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc b/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc index 2bc24c15..ca174823 100644 --- a/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc +++ b/mindspore-lite/tools/converter/parser/caffe/caffe_upsample_parser.cc @@ -19,7 +19,7 @@ #include #include "infer/resize.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc b/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc index e3db1d30..ace278a0 100644 --- a/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc +++ b/mindspore-lite/tools/converter/parser/conv1d_inout_adjust.cc @@ -29,7 +29,7 @@ #include "infer/unsqueeze.h" #include "ops/primitive_c.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc b/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc index 2ca880c6..a027d2d5 100644 --- a/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc +++ b/mindspore-lite/tools/converter/parser/conv2d_transpose_input_adjust.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/conv2d_transpose_input_adjust.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/optimizer/common/gllo_utils.h" #include "infer/cxx_api/conv2d_transpose_fusion.h" #include "mindspore/ops/op_def/op_name.h" diff --git a/mindspore-lite/tools/converter/parser/einsum_adjust.cc b/mindspore-lite/tools/converter/parser/einsum_adjust.cc index d7ea8c30..6e2bf947 100644 --- a/mindspore-lite/tools/converter/parser/einsum_adjust.cc +++ b/mindspore-lite/tools/converter/parser/einsum_adjust.cc @@ -24,7 +24,7 @@ #include "tools/converter/ops/ops_def.h" #include "tools/optimizer/common/gllo_utils.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/unsqueeze.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/converter/parser/inputs_adjust.cc b/mindspore-lite/tools/converter/parser/inputs_adjust.cc index 056a6541..2181204f 100644 --- a/mindspore-lite/tools/converter/parser/inputs_adjust.cc +++ b/mindspore-lite/tools/converter/parser/inputs_adjust.cc @@ -19,7 +19,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/array_ops.h" #include "ops/primitive_c.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" diff --git a/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc b/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc index 896b1d37..91dc6b48 100644 --- a/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc +++ b/mindspore-lite/tools/converter/parser/lstm_adjust_pass.cc @@ -25,7 +25,7 @@ #include 
"tools/lite_exporter/fetch_content.h" #include "tools/common/tensor_util.h" #include "utils/check_convert_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc index 67061df9..8b4a15df 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_activation_parser.cc @@ -21,7 +21,7 @@ #include "infer/cxx_api/prelu_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/cxx_api/activation.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/softplus.h" #include "infer/selu.h" #include "infer/ops_func_impl/hswish.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc index e58aff53..f20c9e9e 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_adder_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_adder_parser.h" #include #include "infer/cxx_api/adder_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc index 73b99356..8e27b2db 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_argmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_argmax_parser.h" #include #include "infer/cxx_api/arg_max_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc index 6aa23f15..e53d880c 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_argmin_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_argmin_parser.h" #include #include "infer/cxx_api/arg_min_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc index b2ecc78a..e04c7666 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -27,7 +27,7 @@ #include "infer/cxx_api/pow_fusion.h" #include "infer/eltwise.h" #include "infer/mod.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc index 24f94ffd..df91c890 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/onnx/onnx_batchnorm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_batchnorm_parser.h" #include #include "infer/fused_batch_norm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc index 29cb2b10..4033c8f6 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_biasadd_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_biasadd_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc index cc8382d4..9e6927fe 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_cast_parser.cc @@ -18,7 +18,7 @@ #include "tools/converter/parser/onnx/onnx_model_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc index 7922ef6e..1b64bc56 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_clip_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_clip_parser.h" #include #include "infer/clip.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc index 90ec2fdd..6e64a2af 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_col2im_parser.cc @@ -18,7 +18,7 @@ #include #include #include "tools/converter/parser/onnx/onnx_model_parser.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/col2im.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc index ca44f90d..d85a2670 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_concat_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_concat_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc index f797a75b..f9492918 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_of_shape_parser.cc @@ -19,7 +19,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "infer/constant_of_shape.h" -#include "nnacl/op_base.h" +#include 
"nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc index 05d0eee9..0e6d2b3f 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_constant_parser.cc @@ -24,7 +24,7 @@ #include "tools/converter/parser/onnx/onnx_node_parser.h" #include "tools/converter/ops/ops_def.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ir/tensor_new.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc index 282a6dc8..88dac606 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv2d_add_parser.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/custom.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc index c95dd15f..03d34522 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_base_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc index 9db3f0ac..f8939a3d 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -20,7 +20,7 @@ #include #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/infer/conv3d.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc index 97e6fdc1..426949ba 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_conv_transpose_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/conv2d_transpose_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc index 406bb813..f7b58108 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_adjust.cc @@ -28,7 +28,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/uniform_real.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/node_util.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc 
b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc index 5582e396..9f72fb56 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_custom_op_parser.cc @@ -20,7 +20,7 @@ #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "tools/converter/ops/ops_def.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/affine_grid.h" #include "infer/histogram.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc index 8df0be87..902f6c0b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_deform_conv2d_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/deformable_conv2d.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { namespace { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc index c0385750..f4b23ad2 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_depth_to_space_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_depth_to_space_parser.h" #include #include "infer/depth_to_space.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc index 4d99d258..d59c5c2b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_dequantize_linear_parser.cc @@ -16,7 +16,7 @@ #include "tools/converter/parser/onnx/onnx_dequantize_linear_parser.h" #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/ops/ops_def.h" #include "ops_utils/op_utils.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc index c7fa69d1..f352c211 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_dropout_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/ops_func_impl/dropout.h" #include "op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc index a31948e6..03479a51 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_einsum_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_einsum_parser.h" #include #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc index 75f6a513..496d7034 
100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_erf_parser.cc @@ -16,7 +16,7 @@ #include "tools/converter/parser/onnx/onnx_erf_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc index a4ee9ef3..1c07eb30 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_expand_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc index bc1a86a8..afd6f01c 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_flash_attention_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/custom.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc index 67e156b7..b43b6f26 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_element_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gather_element_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc index 4f8f4a4c..12e52f65 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_nd_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gather_nd_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc index bbd7c6b6..9bc2ef1a 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gather_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/ops_func_impl/gather.h" #include "op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc index d004683b..eee72d53 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_given_tensor_fill_parser.cc @@ -21,7 +21,7 @@ #include #include "tools/common/tensor_util.h" 
#include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc index 5b545a9c..3ebbb2c4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample3d_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gridsample3d_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_enum.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc index 7cec2ead..0fd0fab7 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gridsample_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_gridsample_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_enum.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc index 4f85a058..2a83e42f 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_gru_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/gru.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" #include "mindspore/ops/infer/grad/gru_v2_grad.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc index 8d025890..17254fb4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_hardswish_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/ops_func_impl/hswish.h" #include "op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc index 8fc86095..a4c345ec 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_identity_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc index 60066820..90466193 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_if_parser.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc index dee5ed4c..ab244232 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc @@ -30,7 +30,7 @@ #include "infer/multinomial.h" #include "infer/affine_grid.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/node_util.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc index ac7fe6b9..e86e5b0b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_instance_norm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_instance_norm_parser.h" #include #include "infer/instance_norm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc index ecb1354e..270c5bad 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_layer_norm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_layer_norm_parser.h" #include #include "infer/cxx_api/layer_norm_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc index 5c9643f4..aab99291 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_less_or_equal_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_less_or_equal_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc index 7d77ae54..5a3da48d 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_log_softmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_log_softmax_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc index dec3f2dd..642d12e4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_loop_parser.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "tools/converter/ops/while.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc 
b/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc index e9502e26..c3182f7f 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_lp_norm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_lp_norm_parser.h" #include #include "infer/lp_normalization.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc index 1e2fce3d..557ce9f0 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_lrn_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_lrn_parser.h" #include #include "infer/lrn.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc index 0d5f2976..de163a06 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_lstm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_lstm_parser.h" #include #include "infer/lstm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc index b9017ca3..771a1d14 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_matmul_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_matmul_parser.h" #include #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc index f9eeff78..067252c5 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -25,7 +25,7 @@ #include "include/registry/node_parser_registry.h" #include "ir/func_graph.h" #include "mindspore/ops/op_def/nn_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/make_tuple.h" #include "infer/return.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc index 549faaff..2655a48b 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -20,7 +20,7 @@ #include #include #include "tools/converter/parser/onnx/onnx_model_parser.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/file_utils.h" #include "utils/ms_utils_secure.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc index ec29bfe5..d94ddd4c 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/onnx/onnx_non_max_suppression_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_non_max_suppression_parser.h" #include #include "infer/non_max_suppression.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc index db618ee1..a3e2d237 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_nonzero_parser.cc @@ -18,7 +18,7 @@ #include #include "tools/converter/parser/onnx/onnx_model_parser.h" #include "infer/where.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc index 6c586781..4704e189 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_onehot_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_onehot_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc index 34628488..c68219f4 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_adjust.cc @@ -22,7 +22,7 @@ #include "ops/primitive_c.h" #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" namespace mindspore::lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc index fadbe8c4..b4d8e5cc 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pad_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/pad_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc index c69f2284..314b50e0 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -20,7 +20,7 @@ #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "mindspore/ops/ops_utils/op_constants.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc index 6620572b..36a1dbe1 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_prompt_flash_attention_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/custom.h" -#include 
"nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_enum.h" #include "op_def/auto_generate/gen_lite_ops.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc index 95da0e3f..7bd3a8db 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_adjust.cc @@ -24,7 +24,7 @@ #include "infer/cxx_api/mat_mul_fusion.h" #include "tools/converter/ops/ops_def.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/quantizer/quantize_util.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc index 2ff60a78..d8c818ad 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_linear_parser.cc @@ -22,7 +22,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "op_def/auto_generate/gen_lite_ops.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc index e548df48..da96c8fd 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_quantize_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_quantize_parser.h" #include #include "infer/quant_dtype_cast.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc index d19164fc..7b8fe37c 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_random_normal_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/random_normal.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc index 5d5be171..400f9f3a 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_range_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_range_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc index f78ac25e..98ad0af9 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_reduce_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/reduce_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git 
a/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc index 77addebf..e6df36a2 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc index 25761368..0fd70b43 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_resize_parser.cc @@ -21,7 +21,7 @@ #include #include "infer/resize.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc index d3f2b006..ee08a635 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_reverse_sequence_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_reverse_sequence_parser.h" #include #include "infer/reverse_sequence.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc index a8b67660..da6d1e80 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_elements_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_scatter_elements_parser.h" #include #include "infer/scatter_elements.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc index 99d97e83..569f67eb 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_scatter_nd_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_scatter_nd_parser.h" #include #include "infer/scatter_nd_update.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc index 334f320f..6f914950 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_shape_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_shape_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc index 82f71b61..771883b5 100644 --- 
a/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_slice_parser.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc index 907efa48..82e7fa62 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_softmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_softmax_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc index a36b3ca4..dede41d2 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_space_to_depth_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_space_to_depth_parser.h" #include #include "infer/space_to_depth.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc index a84cbd49..a2b8a1d0 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_splice_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/splice.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc index 2b489c11..272ae543 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_split_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc index e4c737c3..b0c0d91e 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_squeeze_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc index 691f9460..127f47b5 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/tile_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite 
{ diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc index f924526d..2d101145 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_topk_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_topk_parser.h" #include #include "infer/cxx_api/topk_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc index e6812ffe..c6af3e0e 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_transpose_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc index cb84a25e..8bb81b40 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_trilu_parser.cc @@ -19,7 +19,7 @@ #include #include "infer/tril.h" #include "infer/ops_func_impl/triu.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "op_def/auto_generate/gen_lite_ops.h" diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc index e472f998..953ca617 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_unsqueeze_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/unsqueeze.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc index 46072f0c..09152b31 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_upsample_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/resize.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc index 51d4e54f..6c248599 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_where_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/onnx/onnx_where_parser.h" #include #include "infer/where.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/parser_utils.cc b/mindspore-lite/tools/converter/parser/parser_utils.cc index 6b6e9e86..9cb61926 100644 --- a/mindspore-lite/tools/converter/parser/parser_utils.cc +++ b/mindspore-lite/tools/converter/parser/parser_utils.cc @@ -37,7 +37,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/format/to_format_base.h" #include 
"tools/optimizer/common/pass_manager_extends.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "src/common/common.h" #include "tools/converter/parser/conv2d_transpose_input_adjust.h" diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc index dc83e894..baf428e5 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_activation_parser.cc @@ -18,7 +18,7 @@ #include #include "include/securec.h" #include "infer/cxx_api/activation.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc index 70d46d15..de6525e2 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_argmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_argmax_parser.h" #include #include "infer/cxx_api/arg_max_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc index 2c8e2fd6..1d828584 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_arithmetic_parser.cc @@ -20,7 +20,7 @@ #include "infer/cxx_api/mul_fusion.h" #include "infer/cxx_api/div_fusion.h" #include "infer/cxx_api/sub_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc index eab85684..e7fd0cd2 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_batchnorm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_batchnorm_parser.h" #include #include "infer/fused_batch_norm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc index a83d920f..f00cb239 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_conv_parser.cc @@ -20,7 +20,7 @@ #include #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite { PrimitiveCPtr PytorchConvParser::Parse(const torch::jit::Node *torch_node, std::vector *input_indices) { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc index 57ae7bce..ef98afe6 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_cumsum_parser.cc @@ -17,7 +17,7 @@ #include 
"tools/converter/parser/pytorch/pytorch_cumsum_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc index ac79fa0e..d70c38c0 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_elementop_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_elementop_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc index 56daceeb..0e84d2f1 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_embedding_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_embedding_parser.h" #include #include "infer/ops_func_impl/gather.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc index 1ee8b9d6..7dcb239c 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_flatten_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_flatten_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc index 2dd5300d..b1d6728d 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_gather_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_gather_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc index 871a1b9e..7fce39b9 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_list_construct_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_list_construct_parser.h" #include #include "infer/make_tuple.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc index e6458245..16d955de 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/pytorch/pytorch_logsoftmax_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_logsoftmax_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc index b0a74206..2ec44d43 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_adjust.cc @@ -20,7 +20,7 @@ #include "tools/lite_exporter/fetch_content.h" #include "tools/common/tensor_util.h" #include "utils/check_convert_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/nn_ops.h" #include "mindspore/ops/op_def/sequence_ops.h" diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc index b61234b8..922b2d9f 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_lstm_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_lstm_parser.h" #include #include "infer/lstm.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc index eaf24192..dec94040 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_matmul_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_matmul_parser.h" #include #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc index c53ccb31..182bf2b9 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc @@ -30,7 +30,7 @@ #include "tools/converter/parser/pytorch/torch_graph_transfrom.h" #include "src/common/file_utils.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/make_tuple.h" #include "infer/return.h" diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h b/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h index 399d0851..aba1816d 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_node_parser.h @@ -29,7 +29,7 @@ #include "src/common/log_adapter.h" #include "tools/common/tensor_util.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc index 04a65415..fe0d7a7b 100644 --- 
a/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_non_max_suppression_parser.h" #include #include "infer/non_max_suppression.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc index 88f425c1..ef2855c7 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_permute_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc index 9ed77aca..5333e3d8 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pool_parser.cc @@ -20,7 +20,7 @@ #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" #include "include/registry/converter_context.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc index 50e3fd56..98a6704d 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_pow_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/pytorch/pytorch_pow_parser.h" #include #include "infer/cxx_api/pow_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc index eac82dff..912641da 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_reshape_parser.cc @@ -18,7 +18,7 @@ #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/unsqueeze.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc index 47808075..e3bef32f 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_split_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc index 6b3d6319..9cf17ca9 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_to_parser.cc @@ -18,7 +18,7 @@ #include 
"tools/converter/parser/pytorch/pytorch_node_parser.h" #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc index 8f73044d..a20ceef4 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unaryop_parser.cc @@ -21,7 +21,7 @@ #include "infer/cxx_api/exp_fusion.h" #include "infer/ops_func_impl/tan.h" #include "infer/eltwise.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc index b9db6643..60e5601a 100644 --- a/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc +++ b/mindspore-lite/tools/converter/parser/pytorch/pytorch_unstack_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/unstack.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc b/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc index 088298b4..35832e7b 100644 --- a/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc +++ b/mindspore-lite/tools/converter/parser/tf/functionalize_cond.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "include/errorcode.h" #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "infer/return.h" #include "tools/lite_exporter/fetch_content.h" diff --git a/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc b/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc index 1f09ac49..cbadd724 100644 --- a/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc +++ b/mindspore-lite/tools/converter/parser/tf/functionalize_control_op_pass.cc @@ -20,7 +20,7 @@ #include "tools/converter/parser/tf/functionalize_while.h" #include "tools/converter/parser/tf/functionalize_cond.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore::opt { diff --git a/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc b/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc index 5c457b5b..f14b76dc 100644 --- a/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc +++ b/mindspore-lite/tools/converter/parser/tf/remove_ineffective_control_flow.cc @@ -19,7 +19,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/parser/tf/tf_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc b/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc index 3a3e5d90..51b1cece 100644 --- a/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/tf/tf_fake_quant_parser.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "tools/converter/parser/tf/tf_fake_quant_parser.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/ops/ops_def.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" diff --git a/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h b/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h index e3779f52..ec17e5e8 100644 --- a/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h +++ b/mindspore-lite/tools/converter/parser/tf/tf_node_parser.h @@ -25,7 +25,7 @@ #include "proto/graph.pb.h" #include "ops/primitive_c.h" #include "utils/check_convert_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/parser/parser_utils.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/op_name.h" diff --git a/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc b/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc index 85786141..40617f35 100644 --- a/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc +++ b/mindspore-lite/tools/converter/parser/tf/tf_sparse_to_dense_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/sparse_to_dense.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc b/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc index 02decf59..e4cf2609 100644 --- a/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc +++ b/mindspore-lite/tools/converter/parser/tf_bidirection_gru_cf_fusion.cc @@ -31,7 +31,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" #include "tools/converter/ops/ops_def.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc index db1bf312..e3e1e36c 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_activation_parser.cc @@ -20,7 +20,7 @@ #include "tools/converter/parser/tflite/tflite_util.h" #include "infer/cxx_api/prelu_fusion.h" #include "infer/cxx_api/activation.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc index 4a35715d..798bdf74 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_addn_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc index 68cd48f0..7770dde6 100644 --- 
a/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/arg_max_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc index 0044e9f0..8e164ce5 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_argmin_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/arg_min_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index c1d694f4..abfad806 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -25,7 +25,7 @@ #include "infer/cxx_api/exp_fusion.h" #include "infer/cxx_api/pow_fusion.h" #include "infer/squared_difference.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc index b5215048..1eb54bc2 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_matmul_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc index f3326e4f..0782cfc0 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc @@ -20,7 +20,7 @@ #include #include #include "infer/batch_to_space.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc index 98f63990..64f4c316 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc index 5a5bf434..9fa6ce66 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { 
diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc index 4740f37c..f3d83692 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc index 91cc09a6..384b8cd2 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc index b14924b6..c5dcd895 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_conv_transpose_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/conv2d_transpose_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc index f815fd5e..ac388def 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_custom_parser.cc @@ -30,7 +30,7 @@ #include "infer/fft_imag.h" #include "infer/mfcc.h" #include "infer/rfft.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc index 882c8324..3c283bc0 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/depth_to_space.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 20c31db0..4946d92f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -18,7 +18,7 @@ #include #include "infer/quant_dtype_cast.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index 780b889a..558d4385 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc index 66deea42..884b4cd6 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/fill.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 5e9c32c2..dd86243b 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/full_connection.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index 9b74e6e8..958d0447 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc index 1b920b44..70af88d6 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc index 1ba49291..6026bcf0 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/hashtable_lookup.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc index 65bb608e..9c6fe8e2 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_if_parser.cc @@ -18,7 +18,7 @@ #include "tools/converter/parser/tflite/tflite_if_parser.h" #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc 
b/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc index 34052d26..eb563769 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_inputs_adjust.cc @@ -28,7 +28,7 @@ #include "infer/space_to_batch_nd.h" #include "infer/space_to_depth.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc index 9a11f194..0881aca7 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/l2_normalize_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc index 29122132..00015db0 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_log_softmax_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc index 2f04244c..978219ae 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_logical_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc index d971c9e7..a321c85d 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/lrn.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc index 1097bfcb..22082e17 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/lsh_projection.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc index 50eff237..04a54b37 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc +++ 
b/mindspore-lite/tools/converter/parser/tflite/tflite_matmul_parser.cc @@ -17,7 +17,7 @@ #include "tools/converter/parser/tflite/tflite_matmul_parser.h" #include #include "infer/cxx_api/mat_mul_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc index 5f9c57d6..51f8729a 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -34,7 +34,7 @@ #include "tools/converter/parser/parser_utils.h" #include "tools/converter/parser/lite_model_parser_creator.h" #include "tools/converter/parser/unify_format.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "infer/make_tuple.h" #include "infer/return.h" diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc index 537aaee8..aa0ca448 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc index c19b5218..75dfd50a 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/pad_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index f2d0aa7a..0e51c1da 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -20,7 +20,7 @@ #include #include "infer/cxx_api/avg_pool_fusion.h" #include "infer/cxx_api/max_pool_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index fdd322b0..ec2ffeea 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -18,7 +18,7 @@ #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/quant_dtype_cast.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc index 40d52b46..a7c9f485 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" 
+#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc index 46e9399b..aeb7842b 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index e12a1999..9a28ef6f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/reduce_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc index c464615f..c5c6e05f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc index 4ecf4a00..cfb12176 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -20,7 +20,7 @@ #include #include #include "infer/resize.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc index 12548cd4..f90b380f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc index 4d37da00..998e44d2 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/reverse_sequence.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc index 44145e65..4bb69435 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc 
@@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc index 37790786..81fa131f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc index 3e38d87e..2aa64866 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/skip_gram.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc index 0f35b0d0..fbf79aeb 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/cxx_api/slice_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc index 07d28c7c..2ddfaa87 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc index e2278c0d..1dff3ae4 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/space_to_batch_nd.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc index 758bd392..77a698e5 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/space_to_depth.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc index dd6be778..f9eb6d76 100644 --- 
a/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/sparse_to_dense.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc index ad54a6a9..a9ac3aa2 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc index b9735021..eaa19d3d 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc index 073f170a..dcf8c051 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc index 26dece90..d5404a51 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -18,7 +18,7 @@ #include #include #include "infer/stack.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index bf967d78..9c0f6a1f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc index 6ab7c327..8ad7c2e5 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/tile_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc 
b/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index f775424e..6d03b339 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/cxx_api/topk_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index d0f19780..ecfbf597 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -18,7 +18,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc index ee124cb4..b134cef7 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/unique.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index f75c390d..1780068f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/unstack.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc index 685277f2..7b06d82b 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_util.cc @@ -21,7 +21,7 @@ #include #include "src/common/log_adapter.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc index 51322f79..91bf6f46 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -19,7 +19,7 @@ #include #include #include "infer/where.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc index d04d5b80..9f099387 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_while_parser.cc @@ -19,7 +19,7 @@ #include #include #include "tools/converter/ops/while.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc 
b/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc index a92cede1..c79759e4 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -19,7 +19,7 @@ #include #include #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace lite { diff --git a/mindspore-lite/tools/converter/parser/unify_format.cc b/mindspore-lite/tools/converter/parser/unify_format.cc index 2c814b4f..1dd15904 100644 --- a/mindspore-lite/tools/converter/parser/unify_format.cc +++ b/mindspore-lite/tools/converter/parser/unify_format.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/converter/quantizer/cle_pattern.cc b/mindspore-lite/tools/converter/quantizer/cle_pattern.cc index fa8a0035..949fc5b0 100644 --- a/mindspore-lite/tools/converter/quantizer/cle_pattern.cc +++ b/mindspore-lite/tools/converter/quantizer/cle_pattern.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/converter/quantizer/debug_info_manager.h b/mindspore-lite/tools/converter/quantizer/debug_info_manager.h index 498fa169..936898c5 100644 --- a/mindspore-lite/tools/converter/quantizer/debug_info_manager.h +++ b/mindspore-lite/tools/converter/quantizer/debug_info_manager.h @@ -24,7 +24,7 @@ #include #include "tools/converter/quantizer/quantize_util.h" #include "tools/converter/graphdef_transform.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/statistic_utils.h" #include "src/litert/lite_model.h" #include "src/tensor.h" diff --git a/mindspore-lite/tools/converter/quantizer/fse_decoder.cc b/mindspore-lite/tools/converter/quantizer/fse_decoder.cc index 968dd3d0..0595321f 100644 --- a/mindspore-lite/tools/converter/quantizer/fse_decoder.cc +++ b/mindspore-lite/tools/converter/quantizer/fse_decoder.cc @@ -20,7 +20,7 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::lite::quant { namespace { diff --git a/mindspore-lite/tools/converter/quantizer/fse_encoder.cc b/mindspore-lite/tools/converter/quantizer/fse_encoder.cc index 8da15c87..dda1d59f 100644 --- a/mindspore-lite/tools/converter/quantizer/fse_encoder.cc +++ b/mindspore-lite/tools/converter/quantizer/fse_encoder.cc @@ -22,7 +22,7 @@ #include "ir/dtype/type_id.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" #include "tools/converter/quantizer/quantize_util.h" #include "tools/common/statistic_utils.h" diff --git a/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc b/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc index 7f30b48e..3f1aba2f 
100644 --- a/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc +++ b/mindspore-lite/tools/converter/quantizer/full_quant_quantizer.cc @@ -37,7 +37,7 @@ #include "tools/common/tensor_util.h" #include "src/common/utils.h" #include "tools/common/node_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "tools/converter/quantizer/bias_correction_strategy.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/converter/quantizer/gptq.h b/mindspore-lite/tools/converter/quantizer/gptq.h index c9ce4afe..8a9553e4 100644 --- a/mindspore-lite/tools/converter/quantizer/gptq.h +++ b/mindspore-lite/tools/converter/quantizer/gptq.h @@ -21,7 +21,7 @@ #include #include "tools/converter/quantizer/quantizer.h" #include "src/tensor.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" #include "tools/converter/quantizer/gptq_quantizer.h" namespace mindspore::lite::quant { diff --git a/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h b/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h index 37ecbc95..cf183900 100644 --- a/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h +++ b/mindspore-lite/tools/converter/quantizer/gptq_quantizer.h @@ -29,7 +29,7 @@ #include "tools/converter/cxx_api/converter_para.h" #include "ir/func_graph.h" #include "ir/anf.h" -#include "nnacl/matmul_parameter.h" +#include "nnacl_c/matmul_parameter.h" namespace mindspore { namespace lite::quant { diff --git a/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc b/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc index c7d3b726..50f6c7c4 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc +++ b/mindspore-lite/tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.cc @@ -25,7 +25,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "include/backend/optimizer/graph_optimizer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc b/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc index 187c2c0c..9d819568 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc +++ b/mindspore-lite/tools/converter/quantizer/quant_helper/ffn_full_quant.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "include/backend/optimizer/graph_optimizer.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "src/common/log_adapter.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/converter/quantizer/quant_param_holder.h b/mindspore-lite/tools/converter/quantizer/quant_param_holder.h index 35ff9d07..0bf8e2e9 100644 --- a/mindspore-lite/tools/converter/quantizer/quant_param_holder.h +++ b/mindspore-lite/tools/converter/quantizer/quant_param_holder.h @@ -22,7 +22,7 @@ #include #include "ir/anf.h" #include "schema/inner/model_generated.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "tools/converter/quantizer/quant_params.h" diff --git 
a/mindspore-lite/tools/converter/quantizer/quant_strategy.cc b/mindspore-lite/tools/converter/quantizer/quant_strategy.cc
index eb802a64..17f5bb39 100644
--- a/mindspore-lite/tools/converter/quantizer/quant_strategy.cc
+++ b/mindspore-lite/tools/converter/quantizer/quant_strategy.cc
@@ -26,7 +26,7 @@
 #include "tools/optimizer/common/gllo_utils.h"
 #include "src/common/log_adapter.h"
 #include "src/common/log_util.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "ops_utils/op_utils.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
diff --git a/mindspore-lite/tools/converter/quantizer/quantize_util.cc b/mindspore-lite/tools/converter/quantizer/quantize_util.cc
index b7fd3fad..02b3c4c3 100644
--- a/mindspore-lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore-lite/tools/converter/quantizer/quantize_util.cc
@@ -47,7 +47,7 @@
 #include "tools/converter/parser/parser_utils.h"
 #include "mindspore/ops/op_def/other_ops.h"
 #include "utils/anf_utils.h"
-#include "mindspore/ops/kernel/cpu/nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h"
diff --git a/mindspore-lite/tools/converter/quantizer/smooth_quant.cc b/mindspore-lite/tools/converter/quantizer/smooth_quant.cc
index 9d3c8d5e..2346c111 100644
--- a/mindspore-lite/tools/converter/quantizer/smooth_quant.cc
+++ b/mindspore-lite/tools/converter/quantizer/smooth_quant.cc
@@ -24,9 +24,9 @@
 #include "tools/converter/quantizer/insert_quant_node_manager.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "thread/threadpool.h"
-#include "nnacl/fp32/scale_fp32.h"
+#include "nnacl_c/fp32/scale_fp32.h"
 #include "infer/cxx_api/scale_fusion.h"
-#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl_c/fp32/matmul_fp32.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h"

diff --git a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc
index 34f192a3..3bf47d7c 100644
--- a/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc
+++ b/mindspore-lite/tools/converter/quantizer/split_shared_bias.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include "tools/optimizer/common/gllo_utils.h"
-#include "mindspore/ops/kernel/cpu/nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "tools/converter/quantizer/quant_params.h"
 #include "tools/converter/quantizer/quantize_util.h"
 #include "tools/lite_exporter/fetch_content.h"
diff --git a/mindspore-lite/tools/converter/registry/CMakeLists.txt b/mindspore-lite/tools/converter/registry/CMakeLists.txt
index 31e4c75f..53595758 100644
--- a/mindspore-lite/tools/converter/registry/CMakeLists.txt
+++ b/mindspore-lite/tools/converter/registry/CMakeLists.txt
@@ -16,7 +16,7 @@ set(REG_SRC ${CONVERT_REG_SRC}
     ${KERNEL_REG_DIR}/../extendrt/delegate/plugin/tensorrt_executor_plugin.cc
     ${KERNEL_REG_DIR}/../extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc
     ${CONVERTER_DIR}/converter_context.cc
-    ${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c
+    ${NNACL_DIR}/tensor_c_utils.c
     ${TOP_DIR}/mindspore-lite/src/common/file_utils.cc
     )
 set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS
diff --git
a/mindspore-lite/tools/converter/registry/model_parser_registry.cc b/mindspore-lite/tools/converter/registry/model_parser_registry.cc
index 043e0081..1421775a 100644
--- a/mindspore-lite/tools/converter/registry/model_parser_registry.cc
+++ b/mindspore-lite/tools/converter/registry/model_parser_registry.cc
@@ -17,7 +17,7 @@
 #include "include/registry/model_parser_registry.h"
 #include
 #include "src/common/log_adapter.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace registry {
diff --git a/mindspore-lite/tools/converter/registry/pass_registry.cc b/mindspore-lite/tools/converter/registry/pass_registry.cc
index 4e81081e..7bbaccb5 100644
--- a/mindspore-lite/tools/converter/registry/pass_registry.cc
+++ b/mindspore-lite/tools/converter/registry/pass_registry.cc
@@ -20,7 +20,7 @@
 #include
 #include
 #include "src/common/log_adapter.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"

 namespace mindspore {
 namespace registry {
diff --git a/mindspore-lite/tools/cropper/build_cropper_config.sh b/mindspore-lite/tools/cropper/build_cropper_config.sh
index e4195a5c..b787f84a 100644
--- a/mindspore-lite/tools/cropper/build_cropper_config.sh
+++ b/mindspore-lite/tools/cropper/build_cropper_config.sh
@@ -252,23 +252,23 @@ getCommonFile() {
   while IFS='' read -r line; do mindrt_files_h+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/core/mindrt/include/thread/*.h)
   others_files_h=(
     "${MINDSPORE_LITE_HOME}"/src/litert/infer_manager.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/infer_register.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/nnacl_utils.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/infer_register.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/nnacl_utils.h
     "${MINDSPORE_LITE_HOME}"/src/common/ops/populate/populate_register.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/op_base.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/op_base.h
     "${MINDSPORE_HOME}"/mindspore/core/include/ir/dtype/type_id.h
     "${MINDSPORE_HOME}"/mindspore/core/include/utils/overload.h
     "${MINDSPORE_LITE_HOME}"/tools/common/option.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/infer.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/common_infer.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensor_c.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/errorcode.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/common_func.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensorlist_c.h
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensorlist_c_utils.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/intrinsics/ms_simd_instructions_fp16.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/infer.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/common_infer.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensor_c.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/errorcode.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/common_func.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensorlist_c.h
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.h
     "${MINDSPORE_HOME}"/mindspore/core/include/utils/log_adapter.h
     "${MINDSPORE_HOME}"/mindspore/core/include/ir/api_tensor_impl.h
     "${MINDSPORE_LITE_HOME}"/src/litert/cxx_api/tensor/tensor_impl.h
@@ -312,29 +312,29 @@ getCommonFile() {
   )
   # sava all assembly files
   assembly_files=()
-  while IFS='' read -r line; do assembly_files+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/assembly/*/*.S)
+  while IFS='' read -r line; do assembly_files+=("$line"); done < <(ls ${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/assembly/*/*.S)
   others_files_c=(
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/nnacl_utils.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/errorcode.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/nnacl_utils.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/errorcode.c
     "${MINDSPORE_LITE_HOME}"/src/litert/infer_manager.cc
     "${MINDSPORE_LITE_HOME}"/src/common/ops/populate/populate_register.cc
     "${MINDSPORE_LITE_HOME}"/src/common/ops/populate/custom_populate.cc
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/infer_register.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/infer/common_infer.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/infer_register.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/shape_fusion_infer.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/infer/common_infer.c
     "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/fp32/shape_fusion_fp32.cc
     "${MINDSPORE_HOME}"/mindspore/core/utils/status.cc
     "${MINDSPORE_HOME}"/mindspore/core/utils/log_adapter.cc
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/kernel.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/tensorlist_c_utils.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/base/format_transpose.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/base/cast_base.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp32/transpose_fp32.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp32/pack_fp32.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp16/pack_fp16.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.c
-    "${MINDSPORE_HOME}"/mindspore/ops/kernel/cpu/nnacl/nnacl_common.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/kernel.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensor_c_utils.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/tensorlist_c_utils.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/base/format_transpose.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/base/cast_base.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp32/transpose_fp32.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp16/pack_fp16.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/fp32/pack_fp32_opt.c
+    "${MINDSPORE_LITE_HOME}"/src/litert/kernel/cpu/nnacl_c/nnacl_common.c
   )
   all_files=("${src_files[@]}" "${regist_files[@]}" "${common_files[@]}" "${runtime_files_cc[@]}"
              "${others_files_c[@]}" "${assembly_files[@]}" "${nnacl_files_cc[@]}" "${mindrt_files[@]}"
@@ -428,7 +428,7 @@ getCommonFile
 getTrainCommonFile
 # get src/common/ops
 getOpsFile "REG_POPULATE\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/common/ops/populate" "prototype" &
-getOpsFile "REG_INFER\(.*?, PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/infer" "prototype" &
+getOpsFile "REG_INFER\(.*?, PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/infer" "prototype" &
 # support for cpu
 getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat32, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeFloat32" &
 getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat16, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeFloat16" &
@@ -436,11 +436,11 @@ getOpsFile "REG_KERNEL\(.*?, kNumberTypeInt8, PrimitiveType_" "${MINDSPORE_LITE_
 getOpsFile "REG_KERNEL\(.*?, kNumberTypeInt32, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeInt32" &
 getOpsFile "REG_KERNEL\(.*?, kNumberTypeBool, PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu" "kNumberTypeInt32" &
 #support for nnacl kernel
-getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeFloat32" "kNumberTypeFloat32" &
-getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeFloat16" "kNumberTypeFloat16" &
-getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeInt8" "kNumberTypeInt8" &
-getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeInt32" "kNumberTypeInt32" &
-getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_HOME}/mindspore/ops/kernel/cpu/nnacl/kernel" "kNumberTypeInt32" "kNumberTypeBool" &
+getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeFloat32" "kNumberTypeFloat32" &
+getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeFloat16" "kNumberTypeFloat16" &
+getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeInt8" "kNumberTypeInt8" &
+getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeInt32" "kNumberTypeInt32" &
+getNnaclKernelFile "REG_KERNEL_CREATOR\(PrimType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl_c/kernel" "kNumberTypeInt32" "kNumberTypeBool" &
 getNnaclKernelFile "NNACL_KERNEL\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl" "kNumberTypeFloat32" "kNumberTypeFloat32" &
 getNnaclKernelFile "NNACL_KERNEL\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl" "kNumberTypeFloat16" "kNumberTypeFloat16" &
 getNnaclKernelFile "NNACL_KERNEL\(PrimitiveType_" "${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu/nnacl" "kNumberTypeInt8" "kNumberTypeInt8" &
diff --git a/mindspore-lite/tools/graph_kernel/common/infer_shape.cc b/mindspore-lite/tools/graph_kernel/common/infer_shape.cc
index 304cce16..24750e33 100644
--- a/mindspore-lite/tools/graph_kernel/common/infer_shape.cc
+++ b/mindspore-lite/tools/graph_kernel/common/infer_shape.cc
@@ -22,9 +22,9 @@
 #include "schema/model_generated.h"
 #include "src/tensor.h"
 #include "src/common/utils.h"
-#include "nnacl/infer/common_infer.h"
-#include "nnacl/infer/infer_register.h"
-#include "nnacl/custom_parameter.h"
+#include "nnacl_c/infer/common_infer.h"
+#include
"nnacl_c/infer/infer_register.h" +#include "nnacl_c/custom_parameter.h" namespace mindspore::graphkernel { using mindspore::lite::RET_ERROR; diff --git a/mindspore-lite/tools/graph_kernel/common/utils.h b/mindspore-lite/tools/graph_kernel/common/utils.h index ce7afd0d..78a6f184 100644 --- a/mindspore-lite/tools/graph_kernel/common/utils.h +++ b/mindspore-lite/tools/graph_kernel/common/utils.h @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/tensor_c.h" +#include "nnacl_c/tensor_c.h" #include "common/kernel_build_info.h" #include "include/backend/kernel_info.h" diff --git a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h b/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h index f1afd16e..8a748554 100644 --- a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h +++ b/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h @@ -20,7 +20,7 @@ #include #include #include "src/litert/lite_kernel.h" -#include "nnacl/custom_parameter.h" +#include "nnacl_c/custom_parameter.h" namespace mindspore::kernel { using AkgParallelLambda = int (*)(int task_id, int num_task, void *cdata); diff --git a/mindspore-lite/tools/lite_exporter/anf_exporter.cc b/mindspore-lite/tools/lite_exporter/anf_exporter.cc index 93d4fbe5..60f1bf4e 100644 --- a/mindspore-lite/tools/lite_exporter/anf_exporter.cc +++ b/mindspore-lite/tools/lite_exporter/anf_exporter.cc @@ -31,7 +31,7 @@ #include "mindspore/ops/op_def/op_name.h" #include "mindspore/ops/ops_utils/op_utils.h" #include "mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/depend.h" #include "infer/cxx_api/partial_fusion.h" #include "infer/make_tuple.h" diff --git a/mindspore-lite/tools/lite_exporter/fetch_content.cc b/mindspore-lite/tools/lite_exporter/fetch_content.cc index 0b904a96..d0456719 100644 --- a/mindspore-lite/tools/lite_exporter/fetch_content.cc +++ b/mindspore-lite/tools/lite_exporter/fetch_content.cc @@ -25,7 +25,7 @@ #include "mindapi/base/format.h" #include "mindspore/ops/op_def/framework_ops.h" #include "mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "src/common/ops/anf_utils.h" #include "src/common/ops/populate/populate_register.h" diff --git a/mindspore-lite/tools/lite_exporter/fetch_content.h b/mindspore-lite/tools/lite_exporter/fetch_content.h index 5b5dadf1..3b370c68 100644 --- a/mindspore-lite/tools/lite_exporter/fetch_content.h +++ b/mindspore-lite/tools/lite_exporter/fetch_content.h @@ -24,7 +24,7 @@ #include "ir/primitive.h" #include "ir/func_graph.h" #include "src/common/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/common/format_utils.cc b/mindspore-lite/tools/optimizer/common/format_utils.cc index 6806a3ce..1dd45fe2 100644 --- a/mindspore-lite/tools/optimizer/common/format_utils.cc +++ b/mindspore-lite/tools/optimizer/common/format_utils.cc @@ -64,7 +64,7 @@ #include "infer/deformable_conv2d.h" #include "infer/roi_align.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/common/graph_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/common/gllo_utils.cc b/mindspore-lite/tools/optimizer/common/gllo_utils.cc index ce5bbb53..90bbeaa3 
100644 --- a/mindspore-lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore-lite/tools/optimizer/common/gllo_utils.cc @@ -36,7 +36,7 @@ #include "frontend/operator/ops.h" #include "include/backend/optimizer/helper.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/helper.h" diff --git a/mindspore-lite/tools/optimizer/common/helper.cc b/mindspore-lite/tools/optimizer/common/helper.cc index b519cc21..1fde0eb3 100644 --- a/mindspore-lite/tools/optimizer/common/helper.cc +++ b/mindspore-lite/tools/optimizer/common/helper.cc @@ -20,7 +20,7 @@ #include #include "tools/optimizer/common/helper.h" #include "mindspore/ops/op_def/sequence_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc b/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc index dcf1a58f..e809de0e 100644 --- a/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc +++ b/mindspore-lite/tools/optimizer/common/multiple_pattern_process_pass.cc @@ -16,7 +16,7 @@ #include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "tools/optimizer/common/helper.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore::opt { AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { diff --git a/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc b/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc index 3ea6c6b3..39cd9f70 100644 --- a/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc +++ b/mindspore-lite/tools/optimizer/const_fold/fold_along_infershape.cc @@ -18,7 +18,7 @@ #include "tools/optimizer/const_fold/fold_along_infershape.h" #include #include "mindspore/ops/op_def/framework_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc b/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc index 63ff7bab..1cb1982a 100644 --- a/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc +++ b/mindspore-lite/tools/optimizer/const_fold/fold_with_infershape.cc @@ -20,7 +20,7 @@ #include #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc b/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc index 135ca138..a935cecd 100644 --- a/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc +++ b/mindspore-lite/tools/optimizer/fisson/eliminate_concat_split.cc @@ -29,7 +29,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/parallel/spliter.h" #include "src/common/log_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include 
"mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fisson/fisson_util.cc b/mindspore-lite/tools/optimizer/fisson/fisson_util.cc index b5f6decc..c751ffd7 100644 --- a/mindspore-lite/tools/optimizer/fisson/fisson_util.cc +++ b/mindspore-lite/tools/optimizer/fisson/fisson_util.cc @@ -26,7 +26,7 @@ #include "infer/make_tuple.h" #include "tools/optimizer/parallel/spliter.h" #include "tools/optimizer/parallel/split_strategy.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" #include "include/registry/converter_context.h" diff --git a/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc b/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc index 857fb093..4dc8858f 100644 --- a/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc +++ b/mindspore-lite/tools/optimizer/fisson/iter_node_outputs.cc @@ -17,7 +17,7 @@ #define USE_DEPRECATED_API #include "tools/optimizer/fisson/iter_node_outputs.h" #include "tools/optimizer/parallel/spliter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc b/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc index 332bdf4d..dd2b6036 100644 --- a/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc +++ b/mindspore-lite/tools/optimizer/fisson/multi_conv_split_pass.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/parallel/split_strategy.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" using mindspore::schema::PrimitiveType_Conv2dTransposeFusion; diff --git a/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc b/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc index 839d53fb..baf877a7 100644 --- a/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc +++ b/mindspore-lite/tools/optimizer/fisson/node_out_shapes.cc @@ -19,7 +19,7 @@ #include #include #include "tools/optimizer/parallel/spliter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc b/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc index de5b5aa4..ff001d9a 100644 --- a/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc +++ b/mindspore-lite/tools/optimizer/format/delete_redundant_transpose.cc @@ -21,7 +21,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/common/node_util.h" #include "tools/converter/quantizer/quant_params.h" diff --git a/mindspore-lite/tools/optimizer/format/to_format_base.cc b/mindspore-lite/tools/optimizer/format/to_format_base.cc index 87257b46..f161f18c 100644 --- a/mindspore-lite/tools/optimizer/format/to_format_base.cc +++ b/mindspore-lite/tools/optimizer/format/to_format_base.cc @@ -26,7 +26,7 @@ #include "src/common/utils.h" #include "tools/common/tensor_util.h" #include "tools/converter/parser/parser_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" 
#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc index 0720af1f..cb372a04 100644 --- a/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/activation_fusion.cc @@ -21,7 +21,7 @@ #include #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" #include "ops_utils/op_utils.h" #include "src/common/utils.h" @@ -48,11 +48,13 @@ STATUS DoFusion(CNodePtr cur_cnode, const CNodePtr &pre_cnode) { MS_CHECK_TRUE_MSG(cur_act_prim->GetAttr(ops::kMaxVal) != nullptr, RET_ERROR, "Get max value failed."); MS_CHECK_TRUE_MSG(pre_act_prim->GetAttr(ops::kMinVal) != nullptr, RET_ERROR, "Get min value failed."); MS_CHECK_TRUE_MSG(cur_act_prim->GetAttr(ops::kMinVal) != nullptr, RET_ERROR, "Get min value failed."); - auto pre_max_val = - pre_act_type == RELU ? FLT_MAX : pre_act_type == RELU6 ? kValueThreshold6 : pre_act_prim->get_max_val(); + auto pre_max_val = pre_act_type == RELU ? FLT_MAX + : pre_act_type == RELU6 ? kValueThreshold6 + : pre_act_prim->get_max_val(); auto pre_min_val = (pre_act_type == RELU || pre_act_type == RELU6) ? 0 : pre_act_prim->get_min_val(); - auto cur_max_val = - cur_act_type == RELU ? FLT_MAX : cur_act_type == RELU6 ? kValueThreshold6 : cur_act_prim->get_max_val(); + auto cur_max_val = cur_act_type == RELU ? FLT_MAX + : cur_act_type == RELU6 ? kValueThreshold6 + : cur_act_prim->get_max_val(); auto cur_min_val = (cur_act_type == RELU || cur_act_type == RELU6) ? 
0 : cur_act_prim->get_min_val(); auto new_max_val = std::min(pre_max_val, cur_max_val); auto new_min_val = std::max(pre_min_val, cur_min_val); diff --git a/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc index 484f8819..da34ced9 100644 --- a/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/add_activation_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/add_fusion.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc index 27b90a50..1bd9fb8f 100644 --- a/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/add_concat_activation_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/activation.h" #include "infer/cxx_api/add_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc b/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc index cb3aa516..6b50a3cf 100644 --- a/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/add_layernorm_fusion.cc @@ -30,7 +30,7 @@ #include "include/common/utils/anfalgo.h" #include "include/backend/anf_runtime_algorithm.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/optimizer/graph/node_infershape.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc b/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc index db29238c..fc50cdd7 100644 --- a/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/adjust_col2im_pass.cc @@ -28,7 +28,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/range_v2.h" #include "mindspore/ops/op_def/image_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc index f4c45cca..d627d93f 100644 --- a/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/affine_activation_fusion.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "infer/cxx_api/activation.h" #include "infer/affine.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc 
b/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc index fe0156ea..bb043793 100644 --- a/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/affine_fusion.cc @@ -25,7 +25,7 @@ #include "infer/splice.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc index d2176d93..b62990a4 100644 --- a/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/antiquant_add_mul_matmul_allreduce_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/infer/all_reduce.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h" #include "ir/anf.h" diff --git a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc index f231088e..5fba9935 100644 --- a/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/batchmatmul_fusion.cc @@ -27,7 +27,7 @@ #include "tools/converter/quantizer/quantize_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc index 85d20831..dfc8fc24 100644 --- a/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc @@ -25,7 +25,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/tensor_util.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc b/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc index 513d06a0..c9bec643 100644 --- a/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/cast_fusion.cc @@ -25,7 +25,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "tools/converter/quantizer/quant_param_holder.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc index f519a7ee..730d4563 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc 
+++ b/mindspore-lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/activation.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index a23f4a7a..da78af0a 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -23,7 +23,7 @@ #include "tools/lite_exporter/fetch_content.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindapi/base/types.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc index 53b1f07c..f9cc8268 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_bn_fusion.cc @@ -20,7 +20,7 @@ #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/fusion/batchnorm_to_scale_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc index 30d9022f..2ea85074 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_conv_fusion.cc @@ -22,7 +22,7 @@ #include "tools/common/tensor_util.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc index 44d868ec..0851cab2 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_pad_fusion.cc @@ -25,7 +25,7 @@ #include "infer/cxx_api/pad_fusion.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops/primitive_c.h" #include "ops_utils/op_utils.h" #include "src/common/utils.h" diff --git a/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc index a3e03803..ea98d48a 100644 --- a/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/conv_scale_fusion.cc @@ -20,7 +20,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" namespace mindspore::opt { diff --git a/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc index c1363d76..cc57a7fc 100644 --- 
a/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/conv_transform_fusion.cc
@@ -26,7 +26,7 @@
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/converter/quantizer/quant_param_holder.h"
 #include "include/securec.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "ops_utils/op_utils.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"

diff --git a/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc
index 284aa043..c9410488 100644
--- a/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/conv_tuple_activation_fusion.cc
@@ -22,7 +22,7 @@
 #include "infer/cxx_api/activation.h"
 #include "infer/cxx_api/conv2d_fusion.h"
 #include "tools/optimizer/common/gllo_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "ops_utils/op_utils.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h"
diff --git a/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc b/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
index df594472..3e9ffd3d 100644
--- a/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
@@ -19,7 +19,7 @@
 #include
 #include "mindspore/ops/op_def/sequence_ops.h"
 #include "tools/optimizer/common/gllo_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h"

 namespace mindspore::opt {
diff --git a/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc b/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc
index 22b9684d..24c7fd96 100644
--- a/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/decoder_layer_fusion.cc
@@ -26,7 +26,7 @@
 #include "mindspore/ops/op_def/lite_ops.h"
 #include "mindspore/ops/op_def/array_ops.h"
 #include "tools/optimizer/common/gllo_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/tuple_get_item.h"
 #include "tools/common/tensor_util.h"
 #include "ops_utils/op_utils.h"
diff --git a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc
index 6990f102..afc29ab5 100644
--- a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc
@@ -30,7 +30,7 @@
 #include "mindspore/ops/op_def/array_ops.h"
 #include "mindspore/ops/op_def/framework_ops.h"
 #include "tools/optimizer/common/gllo_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "infer/tuple_get_item.h"
 #include "tools/common/tensor_util.h"
 #include "ops_utils/op_utils.h"
@@ -992,7 +992,8 @@ STATUS EncoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const Eq
     }
   }
   act_type_ = (is_position_bias_) ? (ActType::ActType_Relu)
-                                  : (is_fast_gelu_) ? (ActType::ActType_FastGelu) : (ActType::ActType_Gelu);
+              : (is_fast_gelu_) ? (ActType::ActType_FastGelu)
+                                : (ActType::ActType_Gelu);
   if (!is_position_bias_ && !is_use_past_ && !is_query_layer_) {
     if (!IsActGELU(func_graph, equiv, is_act_)) {
       return RET_ERROR;
diff --git a/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc b/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc
index 1c21920d..455487af 100644
--- a/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/expanddims_reshape_fusion.cc
@@ -20,7 +20,7 @@
 #include "tools/lite_exporter/fetch_content.h"
 #include "ops_utils/op_utils.h"
 #include "tools/optimizer/common/gllo_utils.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "include/registry/converter_context.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h"
diff --git a/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc b/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc
index 47ac5177..94bcd429 100644
--- a/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc
+++ b/mindspore-lite/tools/optimizer/fusion/ffn_custom_pass.cc
@@ -22,7 +22,7 @@
 #include "tools/optimizer/common/gllo_utils.h"
 #include "mindspore/ops/op_def/lite_ops.h"
 #include "mindspore/ops/infer/custom.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "tools/common/string_util.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h"
diff --git a/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc b/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc
index e210aaa1..91cdb1cc 100644
--- a/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/ffn_fusion.cc
@@ -22,7 +22,7 @@
 #include "mindspore/ops/op_def/lite_ops.h"
 #include "mindspore/ops/infer/custom.h"
 #include "infer/f_f_n.h"
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h"
diff --git a/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc b/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc
index c89c64b3..fbcdb17b 100644
--- a/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc
+++ b/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion_for_custom.cc
@@ -18,7 +18,7 @@
 #include
 #include
 #include
-#include "nnacl/op_base.h"
+#include "nnacl_c/op_base.h"
 #include "ops_utils/op_utils.h"
 #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h"
 #include "infer/cxx_api/flash_attention.h"
diff --git a/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc b/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc
index 56af4daf..0502d6df 100644
--- a/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc
+++ b/mindspore-lite/tools/optimizer/fusion/flash_attention_tik_fusion.cc
@@ -19,7 +19,7 @@
 #include "op_def/auto_generate/gen_lite_ops.h"
 #include "mindspore/ops/op_def/nn_ops.h"
 #include "infer/custom.h"
-#include "nnacl/base/cast_base.h"
+#include "nnacl_c/base/cast_base.h"
 #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h"

 namespace mindspore {
diff --git a/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc
b/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc index 848808ab..614be3ad 100644 --- a/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/fullconnected_add_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/add_fusion.h" #include "infer/cxx_api/full_connection.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc b/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc index 58b5306f..e31920fc 100644 --- a/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/fullconnected_fusion.cc @@ -23,7 +23,7 @@ #include "infer/cxx_api/full_connection.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc index 61fd1096..92b0f7f7 100644 --- a/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/gelu_fusion.cc @@ -21,7 +21,7 @@ #include "infer/cxx_api/activation.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc index f6f18a7f..2e4dcddb 100644 --- a/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/glu_fusion.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "include/common/utils/utils.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc b/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc index f1c70f86..295d33cb 100644 --- a/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/groupnorm_fusion.cc @@ -26,7 +26,7 @@ #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/ops/ops_utils.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc b/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc index 92f27e07..6a7e7088 100644 --- a/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/hard_swish_fusion.cc @@ -17,7 +17,7 @@ #include "tools/optimizer/fusion/hard_swish_fusion.h" #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include 
"infer/cxx_api/activation.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc index a5d90a54..a85a38d8 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_assign_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/nn_optimizer_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc index 9ce4f27b..7fca7f7b 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_concat_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_k.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc index 31f36d95..51229192 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_load_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/nn_optimizer_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_k.h" diff --git a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc index 07f6aaa3..a3210737 100644 --- a/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/kv_cache_mgr_one_branch_fusion.cc @@ -23,7 +23,7 @@ #include "src/common/log_adapter.h" #include "infer/splice.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/math_ops.h" diff --git a/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc index 02869e99..fedb5bf1 100644 --- a/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/leaky_relu_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/activation.h" #include "infer/cxx_api/mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/leaky_relu.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git 
a/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc index 187b9a12..918b634e 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_activation_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/mat_mul_fusion.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc index 11798967..42110b94 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_add_fusion.cc @@ -24,7 +24,7 @@ #include "infer/cxx_api/add_fusion.h" #include "infer/cxx_api/mat_mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc index 761c7dc3..ad98cfbe 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_allreduce_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/infer/all_reduce.h" #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h" #include "ir/anf.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc index 29944d5e..78cc55c3 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/mat_mul_fusion.h" #include "infer/cxx_api/mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc index fa9090a2..1f5ea603 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_scale_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc index a32fc34d..b10a2587 100644 --- a/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc +++ 
b/mindspore-lite/tools/optimizer/fusion/mul_activation_fusion.cc @@ -20,7 +20,7 @@ #include "infer/cxx_api/activation.h" #include "infer/cxx_api/mul_fusion.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc b/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc index 75e1211c..68d334b4 100644 --- a/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/mul_add_fusion.cc @@ -20,7 +20,7 @@ #include #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/add_fusion.h" #include "infer/cxx_api/mul_fusion.h" #include "infer/cxx_api/scale_fusion.h" diff --git a/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc index 80f3771e..bc1025b2 100644 --- a/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/mul_reduce_fusion.cc @@ -29,7 +29,7 @@ #include "infer/cxx_api/mul_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc index 06fac39c..cebc52ee 100644 --- a/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/multi_head_attention_fusion.cc @@ -30,7 +30,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/tuple_get_item.h" #include "tools/common/tensor_util.h" #include "ops_utils/op_utils.h" diff --git a/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc b/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc index 94b50f3b..0b2e98eb 100644 --- a/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/norm_fusion.cc @@ -27,7 +27,7 @@ #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/ops/anf_utils.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc index 5e61d4bf..9a78a2f2 100644 --- a/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/onnx_gelu_fusion.cc @@ -18,7 +18,7 @@ #include "tools/optimizer/fusion/onnx_gelu_fusion.h" #include #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ccsrc/include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include 
"mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" diff --git a/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc index 55ac223d..594023da 100644 --- a/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/prelu_fusion.cc @@ -17,7 +17,7 @@ #include "tools/optimizer/fusion/prelu_fusion.h" #include "mindspore/ops/op_def/math_ops.h" #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/prelu_fusion.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc b/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc index 38bffca6..f8ce59df 100644 --- a/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc @@ -16,7 +16,7 @@ #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "mindspore/ops/op_def/framework_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_q.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc index 756b1aae..f2d73ef3 100644 --- a/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reduce_stack_fusion.cc @@ -22,7 +22,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc b/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc index d17bf854..bc97c6e4 100644 --- a/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc +++ b/mindspore-lite/tools/optimizer/fusion/remove_transitivity_op.cc @@ -23,7 +23,7 @@ #include "tools/optimizer/fusion/strided_slice_checker.h" #include "tools/optimizer/common/gllo_utils.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc b/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc index 6045c9fb..b12a5634 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_like_operator_ablation.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/errorcode.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc index 370a5b86..4d909670 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_reduce_fusion.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/op_name.h" #include 
"tools/lite_exporter/fetch_content.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc index c525934b..7cf375e6 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_reshape_fusion.cc @@ -25,7 +25,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc index 5b2af293..ff949dc1 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_shape_fusion.cc @@ -19,7 +19,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc b/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc index a2307937..87e4c45d 100644 --- a/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/reshape_transpose_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/lite_exporter/fetch_content.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc b/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc index e3277369..1755f1e7 100644 --- a/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/resize_fusion.cc @@ -25,7 +25,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/common/tensor_util.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/resize.h" #include "mindapi/base/types.h" diff --git a/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc b/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc index 91d5c74f..543cdfc5 100644 --- a/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/scale_activation_fusion.cc @@ -22,7 +22,7 @@ #include "infer/cxx_api/scale_fusion.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git 
a/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc b/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc index b0aaaf23..214424bf 100644 --- a/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/scale_base_fusion.cc @@ -21,7 +21,7 @@ #include "tools/common/tensor_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/quant_param_holder.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc b/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc index d68ace82..ff537a4c 100644 --- a/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/scale_scale_fusion.cc @@ -24,7 +24,7 @@ #include "tools/common/tensor_util.h" #include "infer/cxx_api/scale_fusion.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc index 9e058ea8..34c15542 100644 --- a/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/sigmoid_mul_fusion.cc @@ -22,7 +22,7 @@ #include "ops_utils/op_utils.h" #include "include/common/utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc b/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc index fa947d7a..74e9632b 100644 --- a/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/squeeze_expanddims_fusion.cc @@ -22,7 +22,7 @@ #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc b/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc index fd3a7728..1eb956e7 100644 --- a/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/squeeze_fusion.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "infer/unsqueeze.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" diff --git a/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc b/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc index e575974c..d4df9372 100644 --- a/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/strided_slice_fusion.cc @@ -23,7 +23,7 @@ #include 
"tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" #include "ir/func_graph.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/op_name.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc index a9c7d866..fabb55b2 100644 --- a/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tensor_dot_fusion.cc @@ -24,7 +24,7 @@ #include "ops_utils/op_utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc index 8e845de1..f7ed4cc5 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc @@ -32,7 +32,7 @@ #include "src/common/utils.h" #include "tools/common/tensor_util.h" #include "include/common/utils/utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc index c78153f7..91ed7586 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_gelu_fusion.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/lite_ops.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindapi/base/types.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc index 1143a737..34ef3c0b 100644 --- a/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tf_lstm_cell_fusion.cc @@ -27,7 +27,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" #include "tools/optimizer/common/helper.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc index 588558e3..c987f445 100644 --- a/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc @@ -32,7 +32,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/helper.h" #include "include/securec.h" -#include 
"nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc index 59e2ff67..fd9f9098 100644 --- a/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc @@ -25,7 +25,7 @@ #include "tools/converter/quantizer/quant_param_holder.h" #include "tools/converter/quantizer/quantize_util.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc index 3a9ad0a5..cd77724f 100644 --- a/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/tile_matmul_fusion.cc @@ -19,7 +19,7 @@ #include #include "mindspore/ops/op_def/lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc b/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc index 850d4227..aa282764 100644 --- a/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/transpose_fusion.cc @@ -26,7 +26,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/format_utils.h" #include "infer/cxx_api/scale_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" diff --git a/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc b/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc index fab33d73..f68648c5 100644 --- a/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/transpose_gather_fusion.cc @@ -22,7 +22,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/format_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc index 8064276d..60e61d5a 100644 --- a/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/transpose_matmul_fusion.cc @@ -22,7 +22,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include 
"tools/optimizer/common/format_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc b/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc index 70ada19c..e6a40b07 100644 --- a/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc +++ b/mindspore-lite/tools/optimizer/graph/add_tensor_array.cc @@ -26,7 +26,7 @@ #include "infer/tensor_array.h" #include "infer/tensor_array_read.h" #include "infer/tensor_array_write.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/make_tuple.h" #include "infer/return.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" diff --git a/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc b/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc index e5704367..fd6679f4 100644 --- a/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/attr_to_args_pass.cc @@ -17,7 +17,7 @@ #include "tools/optimizer/graph/attr_to_args_pass.h" #include #include "tools/common/node_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "ops/primitive_c.h" #include "ops/base_operator.h" diff --git a/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc b/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc index ce3382f3..f818d0ae 100644 --- a/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc +++ b/mindspore-lite/tools/optimizer/graph/broadcast_for_select.cc @@ -20,7 +20,7 @@ #include #include "mindspore/ops/op_def/array_ops.h" #include "tools/optimizer/common/gllo_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/common/utils/anfalgo.h" #include "mindspore/ccsrc/include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" @@ -45,8 +45,9 @@ ShapeVector CalcBroadcastShape(AnfNodePtr cond, AnfNodePtr x, AnfNodePtr y) { auto cond_size = cond_shape.size(); auto x_size = x_shape.size(); auto y_size = y_shape.size(); - ShapeVector broadcast_shape = - cond_size > x_size ? cond_size > y_size ? cond_shape : y_shape : x_size > y_size ? x_shape : y_shape; + ShapeVector broadcast_shape = cond_size > x_size ? cond_size > y_size ? cond_shape : y_shape + : x_size > y_size ? x_shape + : y_shape; auto n = broadcast_shape.size(); for (size_t i = n; i > 0; --i) { auto cond_i = cond_size < i ? 
1 : cond_shape[cond_size - i]; diff --git a/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc b/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc index e7e25ad1..970d9df1 100644 --- a/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/clip_convert_activation_pass.cc @@ -25,7 +25,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/tensor.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc index 13138233..af4ed935 100644 --- a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc @@ -28,7 +28,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/common/log_adapter.h" #include "tools/common/node_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/registry/converter_context.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc b/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc index 0aa5756a..ffa5b0c2 100644 --- a/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/core_infershape_pass.cc @@ -21,7 +21,7 @@ #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "abstract/ops/primitive_infer_map.h" #include "tools/lite_exporter/fetch_content.h" diff --git a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc index 548f84c2..316be5ac 100644 --- a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc +++ b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc @@ -28,7 +28,7 @@ #include "src/common/common.h" #include "src/common/utils.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "tools/optimizer/graph/specify_graph_input_format.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc b/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc index d345d546..e90cb51e 100644 --- a/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc @@ -26,7 +26,7 @@ #include "src/common/log_adapter.h" #include "tools/common/tensor_util.h" #include "include/securec.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/infershape_pass.cc b/mindspore-lite/tools/optimizer/graph/infershape_pass.cc index 81ed965a..6ae3dcd6 100644 --- a/mindspore-lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/infershape_pass.cc @@ 
-21,7 +21,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/graph/decrease_transpose_algo.h" diff --git a/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc b/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc index 47b53c66..c8b686ab 100644 --- a/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/input_data_type_trans_pass.cc @@ -25,7 +25,7 @@ #include "tools/lite_exporter/fetch_content.h" #include "src/tensor.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" diff --git a/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc b/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc index 1f8008b4..dfd78381 100644 --- a/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/int64_cast_int32_pass.cc @@ -27,7 +27,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "src/tensor.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" diff --git a/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc b/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc index b86ae3b1..27409afd 100644 --- a/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc +++ b/mindspore-lite/tools/optimizer/graph/lite_tensor_extractor.cc @@ -24,7 +24,7 @@ #include "src/tensorlist.h" #include "tools/optimizer/common/format_utils.h" #include "utils/ms_utils_secure.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc b/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc index 155cca61..fc1b9c8c 100644 --- a/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/miniaturization_pass.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/fill.h" #include "src/common/utils.h" diff --git a/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc b/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc index 2971ae29..e912b52c 100644 --- a/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/mul_constant_pass.cc @@ -17,7 +17,7 @@ #define USE_DEPRECATED_API #include "tools/optimizer/graph/mul_constant_pass.h" #include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/cxx_api/mul_fusion.h" #include "src/common/utils.h" diff --git a/mindspore-lite/tools/optimizer/graph/node_infershape.cc b/mindspore-lite/tools/optimizer/graph/node_infershape.cc index 43222ffa..3aaa3224 100644 --- 
a/mindspore-lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore-lite/tools/optimizer/graph/node_infershape.cc @@ -35,7 +35,7 @@ #include "src/tensorlist.h" #include "src/registry/kernel_interface_registry.h" #include "tools/optimizer/graph/lite_tensor_extractor.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "tools/optimizer/format/to_nchw_format.h" #include "tools/optimizer/format/to_nhwc_format.h" diff --git a/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc b/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc index e55755fd..485b1eaf 100644 --- a/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc +++ b/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc @@ -30,7 +30,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "tools/lite_exporter/fetch_content.h" #include "mindspore/ops/op_def/op_name.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" diff --git a/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc index 74fe203a..7ece481d 100644 --- a/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -31,7 +31,7 @@ #include "infer/depend.h" #include "infer/cxx_api/pad_fusion.h" #include "ops_utils/op_utils.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "include/common/utils/utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" diff --git a/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc b/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc index ccc49d4c..e0beb62d 100644 --- a/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc @@ -32,7 +32,7 @@ #include "tools/optimizer/common/helper.h" #include "include/backend/optimizer/helper.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" diff --git a/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc b/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc index 33dce353..87dd2fa1 100644 --- a/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc +++ b/mindspore-lite/tools/optimizer/graph/special_node_postprocess.cc @@ -24,7 +24,7 @@ #include "mindspore/ops/op_def/framework_ops.h" #include "include/errorcode.h" #include "tools/optimizer/common/format_utils.h" -#include "nnacl//op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" diff --git a/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc b/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc index 22cbae3a..5b73b8bf 100644 --- a/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc +++
b/mindspore-lite/tools/optimizer/graph/specify_graph_input_format.cc @@ -25,7 +25,7 @@ #include "tools/converter/parser/parser_utils.h" #include "tools/optimizer/common/format_utils.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" diff --git a/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc b/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc index 72ab75df..18debeb7 100644 --- a/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc +++ b/mindspore-lite/tools/optimizer/graph/specify_graph_output_format.cc @@ -26,7 +26,7 @@ #include "tools/optimizer/common/format_utils.h" #include "tools/lite_exporter/fetch_content.h" #include "src/common/log_adapter.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "infer/make_tuple.h" #include "mindspore/ccsrc/include/common/utils/utils.h" diff --git a/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc b/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc index 2ab1c639..afccb911 100644 --- a/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc +++ b/mindspore-lite/tools/optimizer/graph/transpose_strategy.cc @@ -33,7 +33,7 @@ #include "infer/cxx_api/slice_fusion.h" #include "ops_utils/op_utils.h" #include "tools/lite_exporter/fetch_content.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" diff --git a/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc b/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc index ab455c66..ac4a402b 100644 --- a/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc @@ -23,7 +23,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" diff --git a/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc b/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc index 1ffbd0d7..06daaae8 100644 --- a/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/conv2d_info.cc @@ -30,7 +30,7 @@ #include "tools/optimizer/parallel/operator_info_register.h" #include "tools/optimizer/parallel/spliter.h" #include "tools/optimizer/fisson/fisson_util.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "utils/anf_utils.h" diff --git a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc index dba686ed..7cb15ba0 100644 --- a/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/depthwise_conv2d_info.cc @@ -24,7 +24,7 @@ #include "include/securec.h" #include "mindspore/ops/op_def/conv_pool_ops.h" 
#include "mindspore/ops/op_def/lite_ops.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/common/gllo_utils.h" #include "include/common/utils/utils.h" diff --git a/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc b/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc index 073aa215..d1435844 100644 --- a/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/multi_conv_info.cc @@ -21,7 +21,7 @@ #include "tools/optimizer/parallel/spliter.h" #include "infer/cxx_api/conv2d_fusion.h" #include "tools/optimizer/parallel/split_strategy.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" using mindspore::schema::PrimitiveType_Conv2dTransposeFusion; diff --git a/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc b/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc index b9e86cdc..3cb6fae1 100644 --- a/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc +++ b/mindspore-lite/tools/optimizer/parallel/multi_node_split.cc @@ -17,7 +17,7 @@ #define USE_DEPRECATED_API #include "tools/optimizer/parallel/multi_node_split.h" #include "tools/optimizer/parallel/multi_conv_info.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" namespace mindspore { namespace opt { diff --git a/mindspore-lite/tools/optimizer/parallel/operator_info.cc b/mindspore-lite/tools/optimizer/parallel/operator_info.cc index 30631f39..8075b7ce 100644 --- a/mindspore-lite/tools/optimizer/parallel/operator_info.cc +++ b/mindspore-lite/tools/optimizer/parallel/operator_info.cc @@ -22,7 +22,7 @@ #include "infer/tuple_get_item.h" #include "include/common/utils/utils.h" #include "include/errorcode.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" #include "src/common/log_util.h" diff --git a/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc b/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc index 99f30154..4dd3e50d 100644 --- a/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc +++ b/mindspore-lite/tools/optimizer/parallel/parallel_pass.cc @@ -20,7 +20,7 @@ #include "ir/tensor.h" #include "tools/optimizer/parallel/operator_info_register.h" #include "infer/cxx_api/conv2d_fusion.h" -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "ops_utils/op_utils.h" namespace mindspore { diff --git a/mindspore-lite/tools/optimizer/parallel/split_strategy.cc b/mindspore-lite/tools/optimizer/parallel/split_strategy.cc index 82c45d01..e76b036a 100644 --- a/mindspore-lite/tools/optimizer/parallel/split_strategy.cc +++ b/mindspore-lite/tools/optimizer/parallel/split_strategy.cc @@ -19,7 +19,7 @@ #include #include #include -#include "nnacl/op_base.h" +#include "nnacl_c/op_base.h" #include "src/common/log_util.h" namespace mindspore { -- Gitee From 25a7d8401f6d2fc7b826d84150ee5a3c3d792bf4 Mon Sep 17 00:00:00 2001 From: liuf9 Date: Fri, 1 Aug 2025 15:47:22 +0800 Subject: [PATCH 4/7] update --- build.bat | 2 ++ .../0001-adjust-device-address.patch | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 third_party/patch/mindspore/0001-adjust-device-address.patch diff --git a/build.bat b/build.bat index 1405ccd7..e7e511b3 100644 --- a/build.bat +++ b/build.bat @@ -34,6 +34,8 @@ set ENABLE_FFMPEG=ON set ENABLE_FFMPEG_DOWNLOAD=OFF for /f "tokens=1" %%a in (version.txt) do (set VERSION_STR=%%a) git submodule update --init --remote mindspore +pushd 
"%BASE_PATH%/mindspore" && git apply "%BASE_PATH%/third_party/patch/mindspore/0001-adjust-device-address.patch" --whitespace=fix && popd + ECHO %2%|FINDSTR "^[0-9][0-9]*$" IF %errorlevel% == 0 ( SET threads=%2% diff --git a/third_party/patch/mindspore/0001-adjust-device-address.patch b/third_party/patch/mindspore/0001-adjust-device-address.patch new file mode 100644 index 00000000..dd57c2d1 --- /dev/null +++ b/third_party/patch/mindspore/0001-adjust-device-address.patch @@ -0,0 +1,24 @@ +From 33944796dfb00d3798a31cd0351e92fe6fb9ec0f Mon Sep 17 00:00:00 2001 +From: liuf9 +Date: Fri, 1 Aug 2025 15:39:19 +0800 +Subject: [PATCH] adjust device address + +--- + mindspore/ops/kernel/common/device_address.cc | 1 - + 1 file changed, 1 deletion(-) + +diff --git a/mindspore/ops/kernel/common/device_address.cc b/mindspore/ops/kernel/common/device_address.cc +index 15149f63bcb..a3edadf44f7 100644 +--- a/mindspore/ops/kernel/common/device_address.cc ++++ b/mindspore/ops/kernel/common/device_address.cc +@@ -16,7 +16,6 @@ + + #include "common/device_address.h" + #include "common/format_utils.h" +-#include "runtime/device/res_manager/hal_res_base.h" + + namespace mindspore { + namespace device { +-- +2.39.1.windows.1 + -- Gitee From 946e06e6c064b2b40674ab76e2999a09d12b0572 Mon Sep 17 00:00:00 2001 From: liuf9 Date: Fri, 1 Aug 2025 23:46:18 +0800 Subject: [PATCH 5/7] update2 --- mindspore-lite/cmake/ccsrc_extendrt.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore-lite/cmake/ccsrc_extendrt.cmake b/mindspore-lite/cmake/ccsrc_extendrt.cmake index a0fbf76a..c671bfa8 100644 --- a/mindspore-lite/cmake/ccsrc_extendrt.cmake +++ b/mindspore-lite/cmake/ccsrc_extendrt.cmake @@ -58,8 +58,8 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) ${CMAKE_CURRENT_SOURCE_DIR}/mock/segment_runner.cc ${CCSRC_DIR}/utils/ms_device_shape_transfer.cc ${CCSRC_DIR}/kernel/kernel_info.cc - ${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc - ${CCSRC_DIR}/runtime/device/kernel_runtime.cc + # ${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc + # ${CCSRC_DIR}/runtime/device/kernel_runtime.cc ${CCSRC_DIR}/runtime/device/memory_scheduler.cc ${CCSRC_DIR}/runtime/device/memory_offload_strategy.cc ${CCSRC_DIR}/runtime/device/res_manager/memory_manager.cc -- Gitee From 023aab46668c06003dca5517d075c752e472ef5c Mon Sep 17 00:00:00 2001 From: liuf9 Date: Sat, 2 Aug 2025 00:05:22 +0800 Subject: [PATCH 6/7] update --- mindspore-lite/cmake/ccsrc_converter.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore-lite/cmake/ccsrc_converter.cmake b/mindspore-lite/cmake/ccsrc_converter.cmake index abddb41c..68340002 100644 --- a/mindspore-lite/cmake/ccsrc_converter.cmake +++ b/mindspore-lite/cmake/ccsrc_converter.cmake @@ -43,7 +43,7 @@ if(MSLITE_ENABLE_CONVERTER) ${OPS_DIR}/kernel/common/oplib/oplib.cc ${CCSRC_DIR}/kernel/kernel_info.cc ${CCSRC_DIR}/utils/ms_device_shape_transfer.cc - ${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc + # ${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc ${CCSRC_DIR}/runtime/hardware/device_context_manager.cc ${CCSRC_DIR}/common/runtime_conf/runtime_conf.cc ${CCSRC_DIR}/utils/comm_manager.cc -- Gitee From 72ee31817717d24836f6918fdea9ffab58e2dddb Mon Sep 17 00:00:00 2001 From: liuf9 Date: Sat, 2 Aug 2025 09:24:14 +0800 Subject: [PATCH 7/7] update --- mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc | 2 +- .../adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h | 2 +- 
.../acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc | 2 +- .../tools/graph_kernel/converter/preprocess_weight.cc | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc index d9233c4a..7ad15a2e 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc @@ -48,7 +48,7 @@ #include "include/api/model_group.h" #include "src/common/common.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { namespace { const char *const kExecutionPlan = "execution_plan"; diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h index 652e00f4..f50766ae 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h @@ -30,7 +30,7 @@ #include "backend/ms_backend/ms_backend.h" #include "backend/backend_manager/backend_jit_config.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { class GraphCell::GraphImpl { public: diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc index 78475737..7cbb70e1 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc @@ -28,7 +28,7 @@ #include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h" #include "cxx_api/model/acl/acl_vm/acl_vm.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore { API_MODEL_REG(Ascend310, AclModelMulti); diff --git a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc b/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc index e745e31d..51863d8e 100644 --- a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc +++ b/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc @@ -21,7 +21,7 @@ #include "utils/anf_utils.h" #include "backend/common/graph_kernel/core/graph_kernel_callback.h" #include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "ir/tensor_api.h" +#include "ir/tensor_new.h" namespace mindspore::graphkernel { constexpr size_t kConv2dDataIndex = 1; -- Gitee